mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-13 20:29:55 +00:00
Compare commits
36 Commits
feature/ke
...
v0.21.11
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6dee89379b | ||
|
|
76db4f801a | ||
|
|
6c2ed4b4f2 | ||
|
|
2541c78dd0 | ||
|
|
97b6e79809 | ||
|
|
6ad3847615 | ||
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -116,7 +116,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
- name: Install rsrc
|
- name: Install rsrc
|
||||||
run: go install github.com/akavel/rsrc@v0.10.2
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
- name: Generate windows rsrc
|
- name: Generate windows rsrc
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ jobs:
|
|||||||
CI_NETBIRD_MGMT_IDP: "none"
|
CI_NETBIRD_MGMT_IDP: "none"
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||||
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
tags:
|
||||||
|
- legacy_appindicator
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -55,9 +57,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
@@ -75,9 +74,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
|
|||||||
@@ -70,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
|
|
||||||
### Start using NetBird
|
### Start using NetBird
|
||||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
|
||||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
|
||||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
|
||||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|
||||||
### Roadmap
|
### Roadmap
|
||||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -35,6 +36,11 @@ type RouteListener interface {
|
|||||||
routemanager.RouteListener
|
routemanager.RouteListener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||||
|
type DnsReadyListener interface {
|
||||||
|
dns.ReadyListener
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -49,6 +55,7 @@ type Client struct {
|
|||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
deviceName string
|
deviceName string
|
||||||
routeListener routemanager.RouteListener
|
routeListener routemanager.RouteListener
|
||||||
|
onHostDnsFn func([]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener) error {
|
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
c.onHostDnsFn = func([]string) {}
|
||||||
|
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetConnectionListener set the network connection listener
|
// SetConnectionListener set the network connection listener
|
||||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
c.recorder.SetConnectionListener(listener)
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
|||||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// DNSList is a wrapper of []string
|
||||||
|
type DNSList struct {
|
||||||
|
items []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new DNS address to the collection
|
||||||
|
func (array *DNSList) Add(s string) {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array *DNSList) Get(i int) (string, error) {
|
||||||
|
if i >= len(array.items) || i < 0 {
|
||||||
|
return "", fmt.Errorf("out of range")
|
||||||
|
}
|
||||||
|
return array.items[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array *DNSList) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDNSList_Get(t *testing.T) {
|
||||||
|
l := DNSList{
|
||||||
|
items: make([]string, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := l.Get(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(-1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Print(err)
|
log.Debug(err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. " +
|
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
"You can verify the connection by running:\n\n" +
|
||||||
"Run the status command: \n\n" +
|
" netbird status\n\n")
|
||||||
" netbird status\n\n" +
|
return err
|
||||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type Manager interface {
|
|||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
direction RuleDirection,
|
||||||
action Action,
|
action Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -60,5 +61,8 @@ type Manager interface {
|
|||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
|
// Flush the changes to firewall controller
|
||||||
|
Flush() error
|
||||||
|
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -35,6 +36,8 @@ type Manager struct {
|
|||||||
inputDefaultRuleSpecs []string
|
inputDefaultRuleSpecs []string
|
||||||
outputDefaultRuleSpecs []string
|
outputDefaultRuleSpecs []string
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
|
rulesets map[string]ruleset
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -43,6 +46,11 @@ type iFaceMapper interface {
|
|||||||
Address() iface.WGAddress
|
Address() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ruleset struct {
|
||||||
|
rule *Rule
|
||||||
|
ips map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||||
outputDefaultRuleSpecs: []string{
|
outputDefaultRuleSpecs: []string{
|
||||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||||
|
rulesets: make(map[string]ruleset),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Init(); err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
// init clients for booth ipv4 and ipv6
|
||||||
@@ -58,13 +71,17 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
m.ipv4Client = ipv4Client
|
if isIptablesClientAvailable(ipv4Client) {
|
||||||
|
m.ipv4Client = ipv4Client
|
||||||
|
}
|
||||||
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||||
} else {
|
} else {
|
||||||
m.ipv6Client = ipv6Client
|
if isIptablesClientAvailable(ipv6Client) {
|
||||||
|
m.ipv6Client = ipv6Client
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
if err := m.Reset(); err != nil {
|
||||||
@@ -73,6 +90,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment is empty rule ID is used as comment
|
// If comment is empty rule ID is used as comment
|
||||||
@@ -83,6 +105,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -101,22 +124,45 @@ func (m *Manager) AddFiltering(
|
|||||||
if sPort != nil && sPort.Values != nil {
|
if sPort != nil && sPort.Values != nil {
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||||
}
|
}
|
||||||
|
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
if comment == "" {
|
if comment == "" {
|
||||||
comment = ruleID
|
comment = ruleID
|
||||||
}
|
}
|
||||||
|
|
||||||
specs := m.filterRuleSpecs(
|
if ipsetName != "" {
|
||||||
"filter",
|
rs, rsExists := m.rulesets[ipsetName]
|
||||||
ip,
|
if !rsExists {
|
||||||
string(protocol),
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
sPortVal,
|
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||||
dPortVal,
|
}
|
||||||
direction,
|
if err := ipset.Create(ipsetName); err != nil {
|
||||||
action,
|
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||||
comment,
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsExists {
|
||||||
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
|
rs.ips[ip.String()] = ruleID
|
||||||
|
return &Rule{
|
||||||
|
ruleID: ruleID,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// this is new ipset so we need to create firewall rule for it
|
||||||
|
}
|
||||||
|
|
||||||
|
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
||||||
|
direction, action, comment, ipsetName)
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||||
@@ -144,12 +190,24 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Rule{
|
rule := &Rule{
|
||||||
id: ruleID,
|
ruleID: ruleID,
|
||||||
specs: specs,
|
specs: specs,
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
ipsetName: ipsetName,
|
||||||
v6: ip.To4() == nil,
|
ip: ip.String(),
|
||||||
}, nil
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}
|
||||||
|
if ipsetName != "" {
|
||||||
|
// ipset name is defined and it means that this rule was created
|
||||||
|
// for it, need to assosiate it with ruleset
|
||||||
|
m.rulesets[ipsetName] = ruleset{
|
||||||
|
rule: rule,
|
||||||
|
ips: map[string]string{rule.ip: ruleID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
@@ -170,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
client = m.ipv6Client
|
client = m.ipv6Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||||
|
// delete IP from ruleset IPs list and ipset
|
||||||
|
if _, ok := rs.ips[r.ip]; ok {
|
||||||
|
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||||
|
}
|
||||||
|
delete(rs.ips, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(rs.ips) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and assosiated firewall rule too
|
||||||
|
delete(m.rulesets, r.ipsetName)
|
||||||
|
|
||||||
|
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||||
|
log.Errorf("delete empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
r = rs.rule
|
||||||
|
}
|
||||||
|
|
||||||
if r.dst {
|
if r.dst {
|
||||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||||
}
|
}
|
||||||
@@ -193,6 +276,9 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||||
@@ -233,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for ipsetName := range m.rulesets {
|
||||||
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
if err := ipset.Destroy(ipsetName); err != nil {
|
||||||
|
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
delete(m.rulesets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
func (m *Manager) filterRuleSpecs(
|
func (m *Manager) filterRuleSpecs(
|
||||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||||
direction fw.RuleDirection, action fw.Action, comment string,
|
direction fw.RuleDirection, action fw.Action, comment string,
|
||||||
|
ipsetName string,
|
||||||
) (specs []string) {
|
) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
@@ -249,11 +346,19 @@ func (m *Manager) filterRuleSpecs(
|
|||||||
switch direction {
|
switch direction {
|
||||||
case fw.RuleDirectionIN:
|
case fw.RuleDirectionIN:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
specs = append(specs, "-s", ip.String())
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case fw.RuleDirectionOUT:
|
case fw.RuleDirectionOUT:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
specs = append(specs, "-d", ip.String())
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-d", ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
@@ -335,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
|||||||
}
|
}
|
||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||||
|
if ipsetName == "" {
|
||||||
|
return ""
|
||||||
|
} else if sPort != "" && dPort != "" {
|
||||||
|
return ipsetName + "-sport-dport"
|
||||||
|
} else if sPort != "" {
|
||||||
|
return ipsetName + "-sport"
|
||||||
|
} else if dPort != "" {
|
||||||
|
return ipsetName + "-dport"
|
||||||
|
}
|
||||||
|
return ipsetName
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -68,7 +69,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
@@ -81,33 +82,31 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule1); err != nil {
|
err := manager.DeleteRule(rule1)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule2); err != nil {
|
err := manager.DeleteRule(rule2)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mock := &iFaceMock{
|
||||||
|
NameFunc: func() string {
|
||||||
|
return "lo"
|
||||||
|
},
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// just check on the local interface
|
||||||
|
manager, err := Create(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := manager.Reset()
|
||||||
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var rule1 fw.Rule
|
||||||
|
t.Run("add first rule with set", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.2")
|
||||||
|
port := &fw.Port{Values: []int{8080}}
|
||||||
|
rule1, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||||
|
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
|
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
var rule2 fw.Rule
|
||||||
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.3")
|
||||||
|
port := &fw.Port{
|
||||||
|
Values: []int{443},
|
||||||
|
}
|
||||||
|
rule2, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||||
|
"default", "accept HTTPS traffic from ports range",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule1)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule2)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reset check", func(t *testing.T) {
|
||||||
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "failed to reset")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||||
require.NoError(t, err, "failed to check rule")
|
require.NoError(t, err, "failed to check rule")
|
||||||
@@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -167,9 +248,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package iptables
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
id string
|
ruleID string
|
||||||
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
|
ip string
|
||||||
dst bool
|
dst bool
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/google/uuid"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -29,11 +31,14 @@ const (
|
|||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
FilterOutputChainName = "netbird-acl-output-filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
|
sConn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
|
|
||||||
@@ -43,6 +48,10 @@ type Manager struct {
|
|||||||
filterInputChainIPv6 *nftables.Chain
|
filterInputChainIPv6 *nftables.Chain
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
filterOutputChainIPv6 *nftables.Chain
|
||||||
|
|
||||||
|
rulesetManager *rulesetManager
|
||||||
|
setRemovedIPs map[string]struct{}
|
||||||
|
setRemoved map[string]*nftables.Set
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,8 +63,23 @@ type iFaceMapper interface {
|
|||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
|
// and is permanent. Using same connection for booth type of operations
|
||||||
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
conn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
|
sConn: sConn,
|
||||||
|
|
||||||
|
rulesetManager: newRuleManager(),
|
||||||
|
setRemovedIPs: map[string]struct{}{},
|
||||||
|
setRemoved: map[string]*nftables.Set{},
|
||||||
|
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
ipset *nftables.Set
|
||||||
table *nftables.Table
|
table *nftables.Table
|
||||||
chain *nftables.Chain
|
chain *nftables.Chain
|
||||||
)
|
)
|
||||||
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawIP := ip.To4()
|
||||||
|
if rawIP == nil {
|
||||||
|
rawIP = ip.To16()
|
||||||
|
}
|
||||||
|
|
||||||
|
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||||
|
|
||||||
|
if ipsetName != "" {
|
||||||
|
// if we already have set with given name, just add ip to the set
|
||||||
|
// and return rule with new ID in other case let's create rule
|
||||||
|
// with fresh created set and set element
|
||||||
|
|
||||||
|
var isSetNew bool
|
||||||
|
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
|
}
|
||||||
|
isSetNew = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSetNew {
|
||||||
|
// if we already have nftables rules with set for given direction
|
||||||
|
// just add new rule to the ruleset and return new fw.Rule object
|
||||||
|
|
||||||
|
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||||
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
|
}
|
||||||
|
// if ipset exists but it is not linked to rule for given direction
|
||||||
|
// create new rule for direction and bind ipset to it later
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
ifaceKey = expr.MetaKeyOIFNAME
|
||||||
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
// in that case not add IP match expression into the rule definition
|
||||||
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
var adrLen, adrOffset uint32
|
addrLen := uint32(len(rawIP))
|
||||||
if ip.To4() == nil {
|
addrOffset := uint32(12)
|
||||||
adrLen = 16
|
if addrLen == 16 {
|
||||||
adrOffset = 8
|
addrOffset = 8
|
||||||
} else {
|
|
||||||
adrLen = 4
|
|
||||||
adrOffset = 12
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// change to destination address position if need
|
// change to destination address position if need
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
adrOffset += adrLen
|
addrOffset += addrLen
|
||||||
}
|
}
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: adrOffset,
|
Offset: addrOffset,
|
||||||
Len: adrLen,
|
Len: addrLen,
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
// add individual IP for match if no ipset defined
|
||||||
|
if ipset == nil {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rawIP,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ipsetName,
|
||||||
|
SetID: ipset.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
if sPort != nil && len(sPort.Values) != 0 {
|
||||||
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
|
|||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
id := uuid.New().String()
|
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
|
||||||
|
|
||||||
_ = m.conn.InsertRule(&nftables.Rule{
|
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Position: 0,
|
Position: 0,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
if err := m.conn.Flush(); err != nil {
|
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
list, err := m.conn.GetRules(table, chain)
|
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||||
if err != nil {
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
return nil, err
|
}
|
||||||
|
|
||||||
|
// getRulesetID returns ruleset ID based on given parameters
|
||||||
|
func (m *Manager) getRulesetID(
|
||||||
|
ip net.IP,
|
||||||
|
proto fw.Protocol,
|
||||||
|
sPort *fw.Port,
|
||||||
|
dPort *fw.Port,
|
||||||
|
direction fw.RuleDirection,
|
||||||
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) string {
|
||||||
|
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||||
|
if sPort != nil {
|
||||||
|
rulesetID += sPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
if dPort != nil {
|
||||||
|
rulesetID += dPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
rulesetID += strconv.Itoa(int(action))
|
||||||
|
if ipsetName == "" {
|
||||||
|
return "ip:" + ip.String() + rulesetID
|
||||||
|
}
|
||||||
|
return "set:" + ipsetName + rulesetID
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSet in given table by name
|
||||||
|
func (m *Manager) createSet(
|
||||||
|
table *nftables.Table,
|
||||||
|
rawIP []byte,
|
||||||
|
name string,
|
||||||
|
) (*nftables.Set, error) {
|
||||||
|
keyType := nftables.TypeIPAddr
|
||||||
|
if len(rawIP) == 16 {
|
||||||
|
keyType = nftables.TypeIP6Addr
|
||||||
|
}
|
||||||
|
// else we create new ipset and continue creating rule
|
||||||
|
ipset := &nftables.Set{
|
||||||
|
Name: name,
|
||||||
|
Table: table,
|
||||||
|
Dynamic: true,
|
||||||
|
KeyType: keyType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the rule to the chain
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
rule := &Rule{id: id}
|
return nil, fmt.Errorf("create set: %v", err)
|
||||||
for _, r := range list {
|
|
||||||
if bytes.Equal(r.UserData, userData) {
|
|
||||||
rule.Rule = r
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rule.Rule == nil {
|
|
||||||
return nil, fmt.Errorf("rule not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush created set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
// chain returns the chain for the given IP address with specific settings
|
||||||
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||||
tables, err := m.conn.ListTablesOfFamily(family)
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
func (m *Manager) createChainIfNotExists(
|
||||||
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
return nil, fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
Policy: &polAccept,
|
Policy: &polAccept,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain = m.conn.AddChain(chain)
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
shiftDSTAddr := 0
|
shiftDSTAddr := 0
|
||||||
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
},
|
},
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
}
|
}
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
if err := m.conn.Flush(); err != nil {
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
nativeRule, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
if nativeRule.nftRule == nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
if nativeRule.nftSet != nil {
|
||||||
|
// call twice of delete set element raises error
|
||||||
|
// so we need to check if element is already removed
|
||||||
|
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||||
|
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||||
|
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.setRemovedIPs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.rulesetManager.deleteRule(nativeRule) {
|
||||||
|
// deleteRule indicates that we still have IP in the ruleset
|
||||||
|
// it means we should not remove the nftables rule but need to update set
|
||||||
|
// so we prepare IP to be removed from set on the next flush call
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||||
|
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nativeRule.nftRule = nil
|
||||||
|
|
||||||
|
if nativeRule.nftSet != nil {
|
||||||
|
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||||
|
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||||
|
}
|
||||||
|
nativeRule.nftSet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
chains, err := m.conn.ListChains()
|
chains, err := m.rConn.ListChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||||
m.conn.DelChain(c)
|
m.rConn.DelChain(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := m.conn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == FilterTableName {
|
||||||
m.conn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (m *Manager) Flush() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set must be removed after flush rule changes
|
||||||
|
// otherwise we will get error
|
||||||
|
for _, s := range m.setRemoved {
|
||||||
|
m.rConn.FlushSet(s)
|
||||||
|
m.rConn.DelSet(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.setRemoved) > 0 {
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.setRemovedIPs = map[string]struct{}{}
|
||||||
|
m.setRemoved = map[string]*nftables.Set{}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime = backoffTime * 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
if table == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := m.rConn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) != 0 {
|
||||||
|
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||||
|
log.Errorf("failed to set rule handle: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
func encodePort(port fw.Port) []byte {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
fw.RuleDirectionIN,
|
fw.RuleDirectionIN,
|
||||||
fw.ActionDrop,
|
fw.ActionDrop,
|
||||||
"",
|
"",
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
// test expectations:
|
// test expectations:
|
||||||
// 1) regular rule
|
// 1) regular rule
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
// 2) "accept extra routed traffic rule" for the interface
|
||||||
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
err = manager.DeleteRule(rule)
|
err = manager.DeleteRule(rule)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// test expectations:
|
// test expectations:
|
||||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(); err != nil {
|
||||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
if i%100 == 0 {
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
*nftables.Rule
|
nftRule *nftables.Rule
|
||||||
id string
|
nftSet *nftables.Set
|
||||||
|
|
||||||
|
ruleID string
|
||||||
|
ip []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||||
|
type nftRuleset struct {
|
||||||
|
nftRule *nftables.Rule
|
||||||
|
nftSet *nftables.Set
|
||||||
|
issuedRules map[string]*Rule
|
||||||
|
rulesetID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type rulesetManager struct {
|
||||||
|
rulesets map[string]*nftRuleset
|
||||||
|
|
||||||
|
nftSetName2rulesetID map[string]string
|
||||||
|
issuedRuleID2rulesetID map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRuleManager() *rulesetManager {
|
||||||
|
return &rulesetManager{
|
||||||
|
rulesets: map[string]*nftRuleset{},
|
||||||
|
|
||||||
|
nftSetName2rulesetID: map[string]string{},
|
||||||
|
issuedRuleID2rulesetID: map[string]string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||||
|
ruleset, ok := r.rulesets[rulesetID]
|
||||||
|
return ruleset, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) createRuleset(
|
||||||
|
rulesetID string,
|
||||||
|
nftRule *nftables.Rule,
|
||||||
|
nftSet *nftables.Set,
|
||||||
|
) *nftRuleset {
|
||||||
|
ruleset := nftRuleset{
|
||||||
|
rulesetID: rulesetID,
|
||||||
|
nftRule: nftRule,
|
||||||
|
nftSet: nftSet,
|
||||||
|
issuedRules: map[string]*Rule{},
|
||||||
|
}
|
||||||
|
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||||
|
if nftSet != nil {
|
||||||
|
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||||
|
}
|
||||||
|
return &ruleset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) addRule(
|
||||||
|
ruleset *nftRuleset,
|
||||||
|
ip []byte,
|
||||||
|
) (*Rule, error) {
|
||||||
|
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||||
|
return nil, fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
nftRule: ruleset.nftRule,
|
||||||
|
nftSet: ruleset.nftSet,
|
||||||
|
ruleID: xid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset.issuedRules[rule.ruleID] = &rule
|
||||||
|
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteRule from ruleset and returns true if contains other rules
|
||||||
|
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||||
|
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset := r.rulesets[rulesetID]
|
||||||
|
if ruleset.nftRule == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||||
|
delete(ruleset.issuedRules, rule.ruleID)
|
||||||
|
|
||||||
|
if len(ruleset.issuedRules) == 0 {
|
||||||
|
delete(r.rulesets, ruleset.rulesetID)
|
||||||
|
if rule.nftSet != nil {
|
||||||
|
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||||
|
//
|
||||||
|
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||||
|
// we set correct handle value to it.
|
||||||
|
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||||
|
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||||
|
ruleset, ok := r.rulesets[string(split[0])]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
*ruleset.nftRule = *nftRule
|
||||||
|
return nil
|
||||||
|
}
|
||||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{
|
||||||
|
UserData: []byte(rulesetID),
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||||
|
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_addRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||||
|
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||||
|
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||||
|
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||||
|
|
||||||
|
ruleset2 := &nftRuleset{
|
||||||
|
rulesetID: "ruleset-2",
|
||||||
|
}
|
||||||
|
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||||
|
require.Error(t, err, "addRule() should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
ip2 := []byte("192.168.1.1")
|
||||||
|
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule2, "rule should not be nil")
|
||||||
|
|
||||||
|
hasNext := rulesetManager.deleteRule(rule)
|
||||||
|
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||||
|
|
||||||
|
// Check that the rule is no longer in the manager.
|
||||||
|
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||||
|
|
||||||
|
hasNext = rulesetManager.deleteRule(rule2)
|
||||||
|
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.0.1")
|
||||||
|
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
nftRuleCopy := nftRule
|
||||||
|
nftRuleCopy.Handle = 2
|
||||||
|
nftRuleCopy.UserData = []byte(rulesetID)
|
||||||
|
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||||
|
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||||
|
// check correct work with references
|
||||||
|
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
nftSet := nftables.Set{
|
||||||
|
ID: 2,
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
|
||||||
|
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||||
|
require.True(t, ok, "getRuleset() failed")
|
||||||
|
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||||
|
|
||||||
|
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||||
|
require.False(t, ok, "getRuleset() failed")
|
||||||
|
}
|
||||||
@@ -21,11 +21,13 @@ type IFaceMapper interface {
|
|||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RuleSet is a set of rules grouped by a string key
|
||||||
|
type RuleSet map[string]Rule
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules []Rule
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules []Rule
|
incomingRules map[string]RuleSet
|
||||||
rulesIndex map[string]int
|
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
|
|
||||||
@@ -48,7 +50,6 @@ type decoder struct {
|
|||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface IFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
outgoingRules: make(map[string]RuleSet),
|
||||||
|
incomingRules: make(map[string]RuleSet),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var p int
|
|
||||||
if direction == fw.RuleDirectionIN {
|
if direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules, r)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.incomingRules) - 1
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules, r)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.outgoingRules) - 1
|
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
m.rulesIndex[r.id] = p
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &r, nil
|
return &r, nil
|
||||||
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := m.rulesIndex[r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.rulesIndex, r.id)
|
|
||||||
|
|
||||||
var toUpdate []Rule
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
if r.direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.incomingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.outgoingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(toUpdate); i++ {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = m.outgoingRules[:0]
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = m.incomingRules[:0]
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
m.rulesIndex = make(map[string]int)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||||
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter imlements same logic for booth direction of the traffic
|
// dropFilter imlements same logic for booth direction of the traffic
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
// check if IP address match by IP
|
var ip net.IP
|
||||||
|
switch ipLayer {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip4.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip4.DstIP
|
||||||
|
}
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip6.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip6.DstIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// default policy is DROP ALL
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
switch ipLayer {
|
continue
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var toUpdate []Rule
|
|
||||||
if in {
|
if in {
|
||||||
r.direction = fw.RuleDirectionIN
|
r.direction = fw.RuleDirectionIN
|
||||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.incomingRules
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.outgoingRules
|
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range toUpdate {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -345,14 +359,18 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
for _, r := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, r := range m.outgoingRules {
|
for _, arr := range m.outgoingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action = fw.ActionDrop
|
action = fw.ActionDrop
|
||||||
comment = "Test rule 2"
|
comment = "Test rule 2"
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
err = m.DeleteRule(rule2)
|
||||||
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||||
t.Errorf("rule1 still in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
incomingRules: []Rule{},
|
incomingRules: map[string]RuleSet{},
|
||||||
outgoingRules: []Rule{},
|
outgoingRules: map[string]RuleSet{},
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule Rule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.incomingRules[0]
|
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(manager.outgoingRules) != 1 {
|
if len(manager.outgoingRules) != 1 {
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.outgoingRules[0]
|
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if !tt.ip.Equal(addedRule.ip) {
|
||||||
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure rulesIndex is correctly updated
|
|
||||||
index, ok := manager.rulesIndex[addedRule.id]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("expected rule to be in rulesIndex")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if index != 0 {
|
|
||||||
t.Errorf("expected rule index to be 0, got %d", index)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||||
t.Errorf("rules is not empty")
|
t.Errorf("rules is not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
found = true
|
if rule.id == hookID {
|
||||||
break
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
t.Fatalf("The hook was not removed properly.")
|
if rule.id == hookID {
|
||||||
|
t.Fatalf("The hook was not removed properly.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -33,9 +33,22 @@ type Manager interface {
|
|||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
manager firewall.Manager
|
manager firewall.Manager
|
||||||
rulesPairs map[string][]firewall.Rule
|
ipsetCounter int
|
||||||
mutex sync.Mutex
|
rulesPairs map[string][]firewall.Rule
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipsetInfo struct {
|
||||||
|
name string
|
||||||
|
ipCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
|
return &DefaultManager{
|
||||||
|
manager: fm,
|
||||||
|
rulesPairs: make(map[string][]firewall.Rule),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
@@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := d.manager.Flush(); err != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
enableSSH := (networkMap.PeerConfig != nil &&
|
||||||
@@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
|
|
||||||
applyFailed := false
|
applyFailed := false
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[string][]firewall.Rule)
|
||||||
|
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||||
|
|
||||||
|
// calculate which IP's can be grouped in by which ipset
|
||||||
|
// to do that we use rule selector (which is just rule properties without IP's)
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
selector := d.getRuleGroupingSelector(r)
|
||||||
|
ipset, ok := ipsetByRuleSelectors[selector]
|
||||||
|
if !ok {
|
||||||
|
ipset = &ipsetInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipset.ipCount++
|
||||||
|
ipsetByRuleSelectors[selector] = ipset
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
|
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||||
|
ipsetName := ""
|
||||||
|
if ipset.name == "" {
|
||||||
|
d.ipsetCounter++
|
||||||
|
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
|
}
|
||||||
|
ipsetName = ipset.name
|
||||||
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
applyFailed = true
|
applyFailed = true
|
||||||
@@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
|
r *mgmProto.FirewallRule,
|
||||||
|
ipsetName string,
|
||||||
|
) (string, []firewall.Rule, error) {
|
||||||
ip := net.ParseIP(r.PeerIP)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
@@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
var err error
|
var err error
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.FirewallRule_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.FirewallRule_OUT:
|
case mgmProto.FirewallRule_OUT:
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
return ruleID, rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addInRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return append(rules, rule), nil
|
return append(rules, rule), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addOutRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
in := protoMatch{}
|
in := protoMatch{}
|
||||||
out := protoMatch{}
|
out := protoMatch{}
|
||||||
|
|
||||||
|
// trace which type of protocols was squashed
|
||||||
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
|
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||||
// But we zeroed the IP's for protocol if:
|
// But we zeroed the IP's for protocol if:
|
||||||
@@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = map[string]int{}
|
||||||
}
|
}
|
||||||
match := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := match[r.PeerIP]; ok {
|
// special case, when we recieve this all network IP address
|
||||||
|
// it means that rules for that protocol was already optimized on the
|
||||||
|
// management side
|
||||||
|
if r.PeerIP == "0.0.0.0" {
|
||||||
|
squashedRules = append(squashedRules, r)
|
||||||
|
squashedProtocols[r.Protocol] = struct{}{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
match[r.PeerIP] = i
|
|
||||||
|
ipset := protocols[r.Protocol]
|
||||||
|
|
||||||
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipset[r.PeerIP] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
@@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.FirewallRule_UDP,
|
mgmProto.FirewallRule_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||||
@@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return append(rules, squashedRules...), squashedProtocols
|
return append(rules, squashedRules...), squashedProtocols
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.FirewallRule_TCP:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
if *input.PreSharedKey != "" {
|
||||||
*input.PreSharedKey, config.PreSharedKey)
|
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
*input.PreSharedKey, config.PreSharedKey)
|
||||||
refresh = true
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SSHKey == "" {
|
if config.SSHKey == "" {
|
||||||
|
|||||||
@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
// case 4: existing config, but new managementURL has been provided -> update config
|
// case 4: new empty pre-shared key config -> fetch it
|
||||||
|
newPreSharedKey := ""
|
||||||
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
AdminURL: adminURL,
|
||||||
|
ConfigPath: path,
|
||||||
|
PreSharedKey: &newPreSharedKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
|
// case 5: existing config, but new managementURL has been provided -> update config
|
||||||
newManagementURL := "https://test.newManagement.url:33071"
|
newManagementURL := "https://test.newManagement.url:33071"
|
||||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: newManagementURL,
|
ManagementURL: newManagementURL,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -24,7 +25,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||||
|
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunClientMobile with main logic on mobile system
|
||||||
|
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||||
|
// in case of non Android os these variables will be nil
|
||||||
|
mobileDependency := MobileDependency{
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
|
RouteListener: routeListener,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
}
|
||||||
|
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in case of non Android os these variables will be nil
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||||
md := MobileDependency{
|
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
RouteListener: routeListener,
|
|
||||||
}
|
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
func newHostManager(_ WGIface) (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,7 +25,7 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -33,7 +31,7 @@ type registryConfigurator struct {
|
|||||||
existingSearchDomains []string
|
existingSearchDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import (
|
|||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@@ -18,7 +16,7 @@ type resolvconf struct {
|
|||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface.Name(),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -3,29 +3,20 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
defaultPort = 53
|
type ReadyListener interface {
|
||||||
customPort = 5053
|
OnReady()
|
||||||
defaultIP = "127.0.0.1"
|
}
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
@@ -33,6 +24,7 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
@@ -42,21 +34,19 @@ type DefaultServer struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
udpFilterHookID string
|
service service
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface *iface.WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig hostDNSConfig
|
currentConfig hostDNSConfig
|
||||||
customAddress *netip.AddrPort
|
|
||||||
enabled bool
|
// permanent related properties
|
||||||
|
permanent bool
|
||||||
|
hostsDnsList []string
|
||||||
|
hostsDnsListLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
@@ -70,9 +60,7 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
addrPort = &parsedAddrPort
|
addrPort = &parsedAddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
var dnsService service
|
||||||
|
if wgInterface.IsUserspaceBind() {
|
||||||
|
dnsService = newServiceViaMemory(wgInterface)
|
||||||
|
} else {
|
||||||
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
|
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||||
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||||
|
ds.permanent = true
|
||||||
|
ds.hostsDnsList = hostsDnsList
|
||||||
|
ds.addHostRootZone()
|
||||||
|
setServerDns(ds)
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
server: &dns.Server{
|
service: dnsService,
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
},
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if initialDnsCfg != nil {
|
return defaultServer
|
||||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if wgInterface.IsUserspaceBind() {
|
|
||||||
defaultServer.evelRuntimeAddressForUserspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultServer, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.wgInterface.IsUserspaceBind() {
|
if s.permanent {
|
||||||
s.evalRuntimeAddress()
|
err = s.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.hostManager, err = newHostManager(s.wgInterface)
|
s.hostManager, err = newHostManager(s.wgInterface)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen runs the listener in a go routine
|
|
||||||
func (s *DefaultServer) listen() {
|
|
||||||
// nil check required in unit tests
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
|
||||||
s.udpFilterHookID = s.filterDNSTraffic()
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsIP returns the DNS resolver server IP address
|
// DnsIP returns the DNS resolver server IP address
|
||||||
//
|
//
|
||||||
// When kernel space interface used it return real DNS server listener IP address
|
// When kernel space interface used it return real DNS server listener IP address
|
||||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||||
func (s *DefaultServer) DnsIP() string {
|
func (s *DefaultServer) DnsIP() string {
|
||||||
if !s.enabled {
|
return s.service.RuntimeIP()
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s.runtimeIP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
@@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.stopListener()
|
s.service.Stop()
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
// udpFilterHookID here empty only in the unit tests
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
s.hostsDnsListLock.Lock()
|
||||||
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
defer s.hostsDnsListLock.Unlock()
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.udpFilterHookID = ""
|
|
||||||
s.listenerIsRunning = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
s.hostsDnsList = hostsDnsList
|
||||||
return nil
|
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||||
|
if ok {
|
||||||
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
s.addHostRootZone()
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
@@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be disabled, we stop the listener or fake resolver
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if !update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
if err := s.stopListener(); err != nil {
|
_ = s.service.Listen()
|
||||||
log.Error(err)
|
} else if !s.permanent {
|
||||||
}
|
s.service.Stop()
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.listen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
@@ -299,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
s.updateLocalResolver(localRecords)
|
s.updateLocalResolver(localRecords)
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
hostUpdate := s.currentConfig
|
||||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||||
hostUpdate.routeAll = false
|
hostUpdate.routeAll = false
|
||||||
@@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
|
||||||
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerMux(update.domain, update.handler)
|
s.service.RegisterMux(update.domain, update.handler)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if update.domain == nbdns.RootZone {
|
||||||
|
isContainRootUpdate = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
existingHandler.stop()
|
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||||
s.deregisterMux(key)
|
s.hostsDnsListLock.Lock()
|
||||||
|
s.addHostRootZone()
|
||||||
|
s.hostsDnsListLock.Unlock()
|
||||||
|
existingHandler.stop()
|
||||||
|
} else {
|
||||||
|
existingHandler.stop()
|
||||||
|
s.service.DeregisterMux(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
@@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
for i, item := range s.currentConfig.domains {
|
for i, item := range s.currentConfig.domains {
|
||||||
if _, found := removeIndex[item.domain]; found {
|
if _, found := removeIndex[item.domain]; found {
|
||||||
s.currentConfig.domains[i].disabled = true
|
s.currentConfig.domains[i].disabled = true
|
||||||
s.deregisterMux(item.domain)
|
s.service.DeregisterMux(item.domain)
|
||||||
removeIndex[item.domain] = i
|
removeIndex[item.domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.domains[i].disabled = false
|
s.currentConfig.domains[i].disabled = false
|
||||||
s.registerMux(domain, handler)
|
s.service.RegisterMux(domain, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) filterDNSTraffic() string {
|
func (s *DefaultServer) addHostRootZone() {
|
||||||
filter := s.wgInterface.GetFilter()
|
handler := newUpstreamResolver(s.ctx)
|
||||||
if filter == nil {
|
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||||
log.Error("can't set DNS filter, filter not initialized")
|
for n, ua := range s.hostsDnsList {
|
||||||
return ""
|
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||||
}
|
}
|
||||||
|
handler.deactivate = func() {}
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
handler.reactivate = func() {}
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
|
||||||
// Decode the packet
|
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
|
||||||
log.Tracef("parse DNS request: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := responseWriter{
|
|
||||||
packet: packet,
|
|
||||||
device: s.wgInterface.GetDevice().Device,
|
|
||||||
}
|
|
||||||
go s.dnsMux.ServeDNS(&writer, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
|
||||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
|
||||||
s.runtimePort = defaultPort
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) evalRuntimeAddress() {
|
|
||||||
defer func() {
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
|
||||||
for i := 0; i < len(network.IP); i++ {
|
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to big.Int
|
|
||||||
endInt := big.NewInt(0)
|
|
||||||
endInt.SetBytes(endIP)
|
|
||||||
|
|
||||||
// subtract fromEnd from the last ip
|
|
||||||
fromEndBig := big.NewInt(int64(fromEnd))
|
|
||||||
resultInt := big.NewInt(0)
|
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
|
||||||
|
|
||||||
return net.IP(resultInt.Bytes()).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
|
||||||
for _, c := range cfg.NameServerGroups {
|
|
||||||
if c.Primary {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mutex sync.Mutex
|
||||||
|
server Server
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||||
|
func GetServerDns() (Server, error) {
|
||||||
|
mutex.Lock()
|
||||||
|
if server == nil {
|
||||||
|
mutex.Unlock()
|
||||||
|
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||||
|
}
|
||||||
|
s := server
|
||||||
|
mutex.Unlock()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setServerDns(newServerServer Server) {
|
||||||
|
mutex.Lock()
|
||||||
|
server = newServerServer
|
||||||
|
defer mutex.Unlock()
|
||||||
|
}
|
||||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetServerDns(t *testing.T) {
|
||||||
|
_, err := GetServerDns()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("invalid dns server instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &MockServer{}
|
||||||
|
setServerDns(srv)
|
||||||
|
|
||||||
|
srvB, err := GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srvB != srv {
|
||||||
|
t.Errorf("missmatch dns instances")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,14 +11,53 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mocWGIface struct {
|
||||||
|
filter iface.PacketFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Name() string {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Address() iface.WGAddress {
|
||||||
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||||
|
return w.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||||
|
w.filter = filter
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.SetLevel(log.TraceLevel)
|
||||||
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
// pretend we are running
|
|
||||||
dnsServer.listenerIsRunning = true
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet(nil)
|
newNet, err := stdnet.NewNet(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
@@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
t.Errorf("set packet filter: %v", err)
|
t.Errorf("set packet filter: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||||
|
if err != nil {
|
||||||
dnsServer.hostManager = newNoopHostMocker()
|
t.Fatalf("%v", err)
|
||||||
dnsServer.listen()
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
if !dnsServer.listenerIsRunning {
|
|
||||||
t.Fatal("dns server listener is not running")
|
|
||||||
}
|
}
|
||||||
|
dnsServer.hostManager = newNoopHostMocker()
|
||||||
|
err = dnsServer.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns server is not running: %s", err)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
d := net.Dialer{
|
d := net.Dialer{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
conn, err := d.DialContext(ctx, network, addr)
|
conn, err := d.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
@@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
dnsMux: dns.DefaultServeMux,
|
service: newServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||||
mux := dns.NewServeMux()
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
var parsedAddrPort *netip.AddrPort
|
var dnsList []string
|
||||||
if addrPort != "" {
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
||||||
parsed, err := netip.ParseAddrPort(addrPort)
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
}
|
return
|
||||||
parsedAddrPort = &parsed
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||||
|
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
update := nbdns.Config{
|
||||||
Net: "udp",
|
ServiceEnable: true,
|
||||||
Handler: mux,
|
CustomZones: []nbdns.CustomZone{
|
||||||
UDPSize: 65535,
|
{
|
||||||
}
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
},
|
||||||
|
},
|
||||||
ds := &DefaultServer{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
ctx: ctx,
|
{
|
||||||
ctxCancel: cancel,
|
NameServers: []nbdns.NameServer{
|
||||||
server: dnsServer,
|
{
|
||||||
dnsMux: mux,
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
NSType: nbdns.UDPNameServerType,
|
||||||
localResolver: &localResolver{
|
Port: 53,
|
||||||
registeredMap: make(registrationMap),
|
},
|
||||||
|
},
|
||||||
|
Enabled: true,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
customAddress: parsedAddrPort,
|
|
||||||
}
|
|
||||||
ds.evalRuntimeAddress()
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
if err != nil {
|
||||||
if err != nil {
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
if lastIP != tt.ip {
|
if err != nil {
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(2, update2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Domains: []string{"customdomain.com"},
|
||||||
|
Primary: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create stdnet: %v", err)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build interface wireguard: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("crate and init wireguard interface: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pf, err := uspfilter.Create(wgIface)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.SetFilter(pf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("set packet filter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||||
|
return &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{
|
||||||
|
Timeout: time.Second * 3,
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
return d.DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
)
|
||||||
|
|
||||||
|
type service interface {
|
||||||
|
Listen() error
|
||||||
|
Stop()
|
||||||
|
RegisterMux(domain string, handler dns.Handler)
|
||||||
|
DeregisterMux(key string)
|
||||||
|
RuntimePort() int
|
||||||
|
RuntimeIP() string
|
||||||
|
}
|
||||||
145
client/internal/dns/service_listener.go
Normal file
145
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaListener struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
customAddr *netip.AddrPort
|
||||||
|
server *dns.Server
|
||||||
|
runtimeIP string
|
||||||
|
runtimePort int
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
s := &serviceViaListener{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: mux,
|
||||||
|
customAddr: customAddr,
|
||||||
|
server: &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to eval runtime address: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimePort() int {
|
||||||
|
return s.runtimePort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimeIP() string {
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
||||||
|
if s.customAddr != nil {
|
||||||
|
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getFirstListenerAvailable()
|
||||||
|
}
|
||||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaMemory struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
runtimeIP string
|
||||||
|
runtimePort int
|
||||||
|
udpFilterHookID string
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||||
|
s := &serviceViaMemory{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
|
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listenerIsRunning = true
|
||||||
|
|
||||||
|
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.listenerIsRunning = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimePort() int {
|
||||||
|
return s.runtimePort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimeIP() string {
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||||
|
filter := s.wgInterface.GetFilter()
|
||||||
|
if filter == nil {
|
||||||
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||||
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
|
log.Tracef("parse DNS request: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := responseWriter{
|
||||||
|
packet: packet,
|
||||||
|
device: s.wgInterface.GetDevice().Device,
|
||||||
|
}
|
||||||
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||||
|
// Calculate the last IP in the CIDR range
|
||||||
|
var endIP net.IP
|
||||||
|
for i := 0; i < len(network.IP); i++ {
|
||||||
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to big.Int
|
||||||
|
endInt := big.NewInt(0)
|
||||||
|
endInt.SetBytes(endIP)
|
||||||
|
|
||||||
|
// subtract fromEnd from the last ip
|
||||||
|
fromEndBig := big.NewInt(int64(fromEnd))
|
||||||
|
resultInt := big.NewInt(0)
|
||||||
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
|
return net.IP(resultInt.Bytes()).String()
|
||||||
|
}
|
||||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
ip string
|
||||||
|
}{
|
||||||
|
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||||
|
{"192.168.0.0/30", "192.168.0.2"},
|
||||||
|
{"192.168.0.0/16", "192.168.255.254"},
|
||||||
|
{"192.168.0.0/24", "192.168.0.254"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error parsing CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||||
|
if lastIP != tt.ip {
|
||||||
|
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
}
|
||||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
GetInterfaceGUIDString() (string, error)
|
||||||
|
}
|
||||||
@@ -190,23 +190,25 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var routes []*route.Route
|
var routes []*route.Route
|
||||||
var dnsCfg *nbdns.Config
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
routes, dnsCfg, err = e.readInitialSettings()
|
routes, err = e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
if e.dnsServer == nil {
|
||||||
|
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
||||||
if e.dnsServer == nil {
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// todo fix custom address
|
// todo fix custom address
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
if e.dnsServer == nil {
|
||||||
if err != nil {
|
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
e.close()
|
if err != nil {
|
||||||
return err
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||||
@@ -1045,14 +1047,13 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
||||||
netMap, err := e.mgmClient.GetNetworkMap()
|
netMap, err := e.mgmClient.GetNetworkMap()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
return routes, nil
|
||||||
return routes, &dnsCfg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -8,7 +9,9 @@ import (
|
|||||||
|
|
||||||
// MobileDependency collect all dependencies for mobile platform
|
// MobileDependency collect all dependencies for mobile platform
|
||||||
type MobileDependency struct {
|
type MobileDependency struct {
|
||||||
TunAdapter iface.TunAdapter
|
TunAdapter iface.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
RouteListener routemanager.RouteListener
|
RouteListener routemanager.RouteListener
|
||||||
|
HostDNSAddresses []string
|
||||||
|
DnsReadyListener dns.ReadyListener
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,33 +28,13 @@ func genKey(format string, input string) string {
|
|||||||
|
|
||||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||||
ctx, cancel := context.WithCancel(parentCTX)
|
manager, err := newNFTablesManager(parentCTX)
|
||||||
|
if err == nil {
|
||||||
if isIptablesSupported() {
|
log.Debugf("nftables firewall manager will be used")
|
||||||
log.Debugf("iptables is supported")
|
return manager
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
|
|
||||||
return &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
log.Debugf("fallback to iptables firewall manager: %s", err)
|
||||||
log.Debugf("iptables is not supported, using nftables")
|
return newIptablesManager(parentCTX)
|
||||||
|
|
||||||
manager := &nftablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
|
||||||
|
|
||||||
return manager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInPair(pair routerPair) routerPair {
|
func getInPair(pair routerPair) routerPair {
|
||||||
|
|||||||
@@ -49,6 +49,28 @@ type iptablesManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if !isIptablesClientAvailable(ipv4Client) {
|
||||||
|
log.Infof("iptables is missing for ipv4")
|
||||||
|
ipv4Client = nil
|
||||||
|
}
|
||||||
|
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if !isIptablesClientAvailable(ipv6Client) {
|
||||||
|
log.Infof("iptables is missing for ipv6")
|
||||||
|
ipv6Client = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &iptablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
ipv4Client: ipv4Client,
|
||||||
|
ipv6Client: ipv6Client,
|
||||||
|
rules: make(map[string]map[string][]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||||
func (i *iptablesManager) CleanRoutingRules() {
|
func (i *iptablesManager) CleanRoutingRules() {
|
||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
@@ -61,24 +83,28 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
|
|
||||||
log.Debug("flushing tables")
|
log.Debug("flushing tables")
|
||||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
if i.ipv4Client != nil {
|
||||||
if err != nil {
|
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
if err != nil {
|
||||||
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
if i.ipv6Client != nil {
|
||||||
if err != nil {
|
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
if err != nil {
|
||||||
}
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("done cleaning up iptables rules")
|
log.Info("done cleaning up iptables rules")
|
||||||
@@ -96,37 +122,41 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
|
|
||||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||||
|
|
||||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
if i.ipv4Client != nil {
|
||||||
if err != nil {
|
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.restoreRules(i.ipv4Client)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
if i.ipv6Client != nil {
|
||||||
if err != nil {
|
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.restoreRules(i.ipv6Client)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
err := i.addJumpRules()
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv4Client)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv6Client)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.addJumpRules()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||||
}
|
}
|
||||||
@@ -140,34 +170,38 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
if i.ipv4Client != nil {
|
||||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||||
if err != nil {
|
|
||||||
return err
|
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules[ipv4][ipv4Forwarding] = rule
|
||||||
|
|
||||||
|
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||||
|
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules[ipv4][ipv4Nat] = rule
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ipv4][ipv4Forwarding] = rule
|
if i.ipv6Client != nil {
|
||||||
|
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||||
|
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules[ipv6][ipv6Forwarding] = rule
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
i.rules[ipv6][ipv6Nat] = rule
|
||||||
}
|
}
|
||||||
i.rules[ipv4][ipv4Nat] = rule
|
|
||||||
|
|
||||||
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
|
||||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv6][ipv6Forwarding] = rule
|
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
|
||||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv6][ipv6Nat] = rule
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -177,35 +211,39 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
var err error
|
var err error
|
||||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||||
if found {
|
if i.ipv4Client != nil {
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
if found {
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||||
if err != nil {
|
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rule, found = i.rules[ipv4][ipv4Nat]
|
||||||
|
if found {
|
||||||
|
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||||
|
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rule, found = i.rules[ipv4][ipv4Nat]
|
if i.ipv6Client == nil {
|
||||||
if found {
|
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
if found {
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||||
if err != nil {
|
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
rule, found = i.rules[ipv6][ipv6Nat]
|
||||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
if found {
|
||||||
if found {
|
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
if err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
rule, found = i.rules[ipv6][ipv6Nat]
|
|
||||||
if found {
|
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
|
||||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -437,3 +475,8 @@ func getIptablesRuleType(table string) string {
|
|||||||
}
|
}
|
||||||
return ruleType
|
return ruleType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager := newIptablesManager(context.TODO())
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
|
|
||||||
manager := &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||||
|
|
||||||
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||||
|
|
||||||
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
pair = routerPair{
|
pair = routerPair{
|
||||||
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
delete(manager.rules, ipv4)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ const (
|
|||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||||
|
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to create nftable rules
|
// constants needed to create nftable rules
|
||||||
@@ -71,14 +74,41 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type nftablesManager struct {
|
type nftablesManager struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
chains map[string]map[string]*nftables.Chain
|
chains map[string]map[string]*nftables.Chain
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
mux sync.Mutex
|
filterTable *nftables.Table
|
||||||
|
defaultForwardRules []*nftables.Rule
|
||||||
|
mux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
mgr := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mgr.isSupported()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mgr.readFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing nftables rules from the system
|
// CleanRoutingRules cleans existing nftables rules from the system
|
||||||
@@ -90,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.conn.FlushTable(n.tableIPv6)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) readFilterTable() error {
|
||||||
|
tables, err := n.conn.ListTables()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == "filter" {
|
||||||
|
n.filterTable = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||||
|
if n.defaultForwardRules[0] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := n.refreshDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
err = n.conn.DelRule(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
||||||
|
}
|
||||||
|
n.defaultForwardRules[i] = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
||||||
|
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if string(rule.UserData) == string(r.UserData) {
|
||||||
|
n.defaultForwardRules[i] = rule
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("unable to find forward accept rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
||||||
|
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
||||||
|
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r := &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
||||||
|
|
||||||
|
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
||||||
|
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
||||||
|
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r = &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||||
@@ -275,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
||||||
|
err = n.acceptForwardRule(pair.source)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to create default forward rule: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("default accept forward rule added")
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
@@ -355,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delte default fwd rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
@@ -386,6 +544,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) isSupported() error {
|
||||||
|
_, err := n.conn.ListChains()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables is not supported: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -14,21 +14,16 @@ import (
|
|||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||||
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
for _, testCase := range insertRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||||
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
for _, testCase := range removeRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
table := manager.tableIPv4
|
table := manager.tableIPv4
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
|
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
||||||
log.Errorf("init connections: %v", err)
|
log.Errorf("init connections: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
|
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
||||||
log.Errorf("run client connection: %v", err)
|
log.Errorf("run client connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,6 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/creack/pty"
|
|
||||||
"github.com/gliderlabs/ssh"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,11 +10,22 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/creack/pty"
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||||
const DefaultSSHPort = 44338
|
const DefaultSSHPort = 44338
|
||||||
|
|
||||||
|
// TerminalTimeout is the timeout for terminal session to be ready
|
||||||
|
const TerminalTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||||
|
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
// DefaultSSHServer is a function that creates DefaultServer
|
// DefaultSSHServer is a function that creates DefaultServer
|
||||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||||
return newDefaultServer(hostKeyPEM, addr)
|
return newDefaultServer(hostKeyPEM, addr)
|
||||||
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||||
|
|
||||||
localUser, err := userNameLookup(session.User())
|
localUser, err := userNameLookup(session.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||||
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Login command: %s", cmd.String())
|
||||||
file, err := pty.Start(cmd)
|
file, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed starting SSH server %v", err)
|
log.Errorf("failed starting SSH server %v", err)
|
||||||
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Debugf("SSH session ended")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||||
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
|||||||
// stdin
|
// stdin
|
||||||
_, err := io.Copy(file, session)
|
_, err := io.Copy(file, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = session.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||||
// stdout
|
timer := time.NewTimer(TerminalTimeout)
|
||||||
_, err := io.Copy(session, file)
|
for {
|
||||||
if err != nil {
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||||
|
_ = session.Exit(1)
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
// stdout
|
||||||
|
writtenBytes, err := io.Copy(session, file)
|
||||||
|
if err != nil && writtenBytes != 0 {
|
||||||
|
_ = session.Exit(0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(TerminalBackoffDelay)
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts SSH server. Blocking
|
// Start starts SSH server. Blocking
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -48,6 +48,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.4.0
|
github.com/mdlayher/socket v0.4.0
|
||||||
github.com/miekg/dns v1.1.43
|
github.com/miekg/dns v1.1.43
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/pion/logging v0.2.2
|
github.com/pion/logging v0.2.2
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -485,6 +485,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
|
|||||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||||
|
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||||
|
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
|
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
|
|||||||
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
|
||||||
return w.tun.Create(mIFaceArgs)
|
return w.tun.Create(mIFaceArgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
|||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
|
||||||
return w.tun.Create()
|
return w.tun.Create()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
||||||
|
log.Info("create tun interface")
|
||||||
var err error
|
var err error
|
||||||
routesString := t.routesToString(mIFaceArgs.Routes)
|
routesString := t.routesToString(mIFaceArgs.Routes)
|
||||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ import (
|
|||||||
|
|
||||||
func (c *tunDevice) Create() error {
|
func (c *tunDevice) Create() error {
|
||||||
if WireGuardModuleIsLoaded() {
|
if WireGuardModuleIsLoaded() {
|
||||||
log.Info("using kernel WireGuard")
|
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
|
||||||
return c.createWithKernel()
|
return c.createWithKernel()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tunModuleIsLoaded() {
|
if !tunModuleIsLoaded() {
|
||||||
return fmt.Errorf("couldn't check or load tun module")
|
return fmt.Errorf("couldn't check or load tun module")
|
||||||
}
|
}
|
||||||
log.Info("using userspace WireGuard")
|
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
|
||||||
var err error
|
var err error
|
||||||
c.netInterface, err = c.createWithUserspace()
|
c.netInterface, err = c.createWithUserspace()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
|||||||
c.wrapper = newDeviceWrapper(tunIface)
|
c.wrapper = newDeviceWrapper(tunIface)
|
||||||
|
|
||||||
// We need to create a wireguard-go device and listen to configuration requests
|
// We need to create a wireguard-go device and listen to configuration requests
|
||||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||||
err = tunDev.Up()
|
err = tunDev.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
|
|||||||
@@ -97,16 +97,9 @@ curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o openid-configuration.js
|
|||||||
|
|
||||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
|
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
|
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES=$(jq -r '.scopes_supported | join(" ")' openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
|
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
|
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
|
||||||
|
|
||||||
if [ "$NETBIRD_USE_AUTH0" == "true" ]; then
|
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api email_verified"
|
|
||||||
else
|
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||||
# user enabled Device Authorization Grant feature
|
# user enabled Device Authorization Grant feature
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||||
|
|||||||
@@ -11,7 +11,10 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
|||||||
NETBIRD_AUTH_AUDIENCE=""
|
NETBIRD_AUTH_AUDIENCE=""
|
||||||
# e.g. netbird-client
|
# e.g. netbird-client
|
||||||
NETBIRD_AUTH_CLIENT_ID=""
|
NETBIRD_AUTH_CLIENT_ID=""
|
||||||
NETBIRD_AUTH_CLIENT_SECRET=""
|
# indicates the scopes that will be requested to the IDP
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
||||||
|
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
||||||
|
# NETBIRD_AUTH_CLIENT_SECRET=""
|
||||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||||
# indicates whether to use Auth0 or not: true or false
|
# indicates whether to use Auth0 or not: true or false
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ NETBIRD_DOMAIN=$CI_NETBIRD_DOMAIN
|
|||||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
|
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
|
||||||
# e.g. netbird-client
|
# e.g. netbird-client
|
||||||
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
|
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=$CI_NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||||
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
|
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||||
# indicates whether to use Auth0 or not: true or false
|
# indicates whether to use Auth0 or not: true or false
|
||||||
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
||||||
|
|||||||
@@ -218,7 +218,11 @@ var (
|
|||||||
if !disableMetrics {
|
if !disableMetrics {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager)
|
idpManager := "disabled"
|
||||||
|
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
|
||||||
|
idpManager = config.IdpManagerConfig.ManagerType
|
||||||
|
}
|
||||||
|
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
|
||||||
go metricsWorker.Run()
|
go metricsWorker.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ const (
|
|||||||
PublicCategory = "public"
|
PublicCategory = "public"
|
||||||
PrivateCategory = "private"
|
PrivateCategory = "private"
|
||||||
UnknownCategory = "unknown"
|
UnknownCategory = "unknown"
|
||||||
|
GroupIssuedAPI = "api"
|
||||||
|
GroupIssuedJWT = "jwt"
|
||||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
@@ -51,6 +53,7 @@ type AccountManager interface {
|
|||||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||||
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||||
|
InviteUser(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||||
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
@@ -78,7 +81,7 @@ type AccountManager interface {
|
|||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
SaveGroup(accountID, userID string, group *Group) error
|
SaveGroup(accountID, userID string, group *Group) error
|
||||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||||
DeleteGroup(accountId, groupID string) error
|
DeleteGroup(accountId, userId, groupID string) error
|
||||||
ListGroups(accountId string) ([]*Group, error)
|
ListGroups(accountId string) ([]*Group, error)
|
||||||
GroupAddPeer(accountId, groupID, peerID string) error
|
GroupAddPeer(accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(accountId, groupID, peerKey string) error
|
GroupDeletePeer(accountId, groupID, peerKey string) error
|
||||||
@@ -139,6 +142,13 @@ type Settings struct {
|
|||||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||||
PeerLoginExpiration time.Duration
|
PeerLoginExpiration time.Duration
|
||||||
|
|
||||||
|
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||||
|
// and add it to account groups.
|
||||||
|
JWTGroupsEnabled bool
|
||||||
|
|
||||||
|
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||||
|
JWTGroupsClaimName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy copies the Settings struct
|
// Copy copies the Settings struct
|
||||||
@@ -146,6 +156,8 @@ func (s *Settings) Copy() *Settings {
|
|||||||
return &Settings{
|
return &Settings{
|
||||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||||
|
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||||
|
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -612,6 +624,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
|||||||
return a.Peers[peerID]
|
return a.Peers[peerID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddJWTGroups to existed groups if they does not exists
|
||||||
|
func (a *Account) AddJWTGroups(groups []string) (int, error) {
|
||||||
|
existedGroups := make(map[string]*Group)
|
||||||
|
for _, g := range a.Groups {
|
||||||
|
existedGroups[g.Name] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
for _, name := range groups {
|
||||||
|
if _, ok := existedGroups[name]; !ok {
|
||||||
|
id := xid.New().String()
|
||||||
|
a.Groups[id] = &Group{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
Issued: GroupIssuedJWT,
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||||
@@ -1241,6 +1275,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Settings.JWTGroupsEnabled {
|
||||||
|
if account.Settings.JWTGroupsClaimName == "" {
|
||||||
|
log.Errorf("JWT groups are enabled but no claim name is set")
|
||||||
|
return account, user, nil
|
||||||
|
}
|
||||||
|
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||||
|
if slice, ok := claim.([]interface{}); ok {
|
||||||
|
var groups []string
|
||||||
|
for _, item := range slice {
|
||||||
|
if g, ok := item.(string); ok {
|
||||||
|
groups = append(groups, g)
|
||||||
|
} else {
|
||||||
|
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err := account.AddJWTGroups(groups)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add JWT groups: %v", err)
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
if err := am.Store.SaveAccount(account); err != nil {
|
||||||
|
log.Errorf("failed to save account: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return account, user, nil
|
return account, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1344,8 +1410,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
allGroup := &Group{
|
allGroup := &Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: "All",
|
Name: "All",
|
||||||
|
Issued: GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||||
@@ -1373,33 +1440,28 @@ func addAllGroup(account *Account) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
func newAccountWithId(accountID, userID, domain string) *Account {
|
||||||
log.Debugf("creating new account")
|
log.Debugf("creating new account")
|
||||||
|
|
||||||
setupKeys := make(map[string]*SetupKey)
|
|
||||||
defaultKey := GenerateDefaultSetupKey()
|
|
||||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{},
|
|
||||||
SetupKeyUnlimitedUsage)
|
|
||||||
setupKeys[defaultKey.Key] = defaultKey
|
|
||||||
setupKeys[oneOffKey.Key] = oneOffKey
|
|
||||||
network := NewNetwork()
|
network := NewNetwork()
|
||||||
peers := make(map[string]*Peer)
|
peers := make(map[string]*Peer)
|
||||||
users := make(map[string]*User)
|
users := make(map[string]*User)
|
||||||
routes := make(map[string]*route.Route)
|
routes := make(map[string]*route.Route)
|
||||||
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userId] = NewAdminUser(userId)
|
users[userID] = NewAdminUser(userID)
|
||||||
dnsSettings := &DNSSettings{
|
dnsSettings := &DNSSettings{
|
||||||
DisabledManagementGroups: make([]string, 0),
|
DisabledManagementGroups: make([]string, 0),
|
||||||
}
|
}
|
||||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
log.Debugf("created new account %s", accountID)
|
||||||
|
|
||||||
acc := &Account{
|
acc := &Account{
|
||||||
Id: accountId,
|
Id: accountID,
|
||||||
SetupKeys: setupKeys,
|
SetupKeys: setupKeys,
|
||||||
Network: network,
|
Network: network,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Users: users,
|
Users: users,
|
||||||
CreatedBy: userId,
|
CreatedBy: userID,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
NameServerGroups: nameServersGroups,
|
NameServerGroups: nameServersGroups,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -53,7 +54,7 @@ func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy
|
|||||||
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(account.SetupKeys) != 2 {
|
if len(account.SetupKeys) != 0 {
|
||||||
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
|
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -460,6 +461,75 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||||
|
userId := "user-id"
|
||||||
|
domain := "test.domain"
|
||||||
|
|
||||||
|
initAccount := newAccountWithId("", userId, domain)
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID := initAccount.Id
|
||||||
|
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||||
|
require.NoError(t, err, "create init user failed")
|
||||||
|
// as initAccount was created without account id we have to take the id after account initialization
|
||||||
|
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||||
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
|
initAccount.Id = acc.Id
|
||||||
|
|
||||||
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||||
|
Domain: domain,
|
||||||
|
UserId: userId,
|
||||||
|
DomainCategory: "test-category",
|
||||||
|
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||||
|
initAccount.Settings.JWTGroupsEnabled = true
|
||||||
|
err := manager.Store.SaveAccount(initAccount)
|
||||||
|
require.NoError(t, err, "save account failed")
|
||||||
|
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||||
|
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||||
|
initAccount.Settings.JWTGroupsEnabled = true
|
||||||
|
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||||
|
err := manager.Store.SaveAccount(initAccount)
|
||||||
|
require.NoError(t, err, "save account failed")
|
||||||
|
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||||
|
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||||
|
|
||||||
|
groupsByNames := map[string]*Group{}
|
||||||
|
for _, g := range account.Groups {
|
||||||
|
groupsByNames[g.Name] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
g1, ok := groupsByNames["group1"]
|
||||||
|
require.True(t, ok, "group1 should be added to the account")
|
||||||
|
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||||
|
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
|
||||||
|
|
||||||
|
g2, ok := groupsByNames["group2"]
|
||||||
|
require.True(t, ok, "group2 should be added to the account")
|
||||||
|
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||||
|
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId("account_id", "testuser", "")
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
@@ -704,20 +774,21 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := createAccount(manager, "test_account", "account_creator", "netbird.cloud")
|
userID := "account_creator"
|
||||||
|
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := account.Network.CurrentSerial() // should be 0
|
serial := account.Network.CurrentSerial() // should be 0
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if setupKey == nil {
|
if err != nil {
|
||||||
t.Errorf("expecting account to have a default setup key")
|
t.Fatal("error creating setup key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -858,16 +929,13 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
if setupKey.Type == SetupKeyReusable {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if setupKey == nil {
|
if err != nil {
|
||||||
t.Errorf("expecting account to have a default setup key")
|
t.Fatal("error creating setup key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1021,7 +1089,10 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.DeleteGroup(account.Id, group.ID); err != nil {
|
// clean policy is pre requirement for delete group
|
||||||
|
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
|
||||||
|
|
||||||
|
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
|
||||||
t.Errorf("delete group: %v", err)
|
t.Errorf("delete group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1042,9 +1113,14 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ const (
|
|||||||
UserBlocked
|
UserBlocked
|
||||||
// UserUnblocked indicates that a user unblocked another user
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
UserUnblocked
|
UserUnblocked
|
||||||
|
// GroupDeleted indicates that a user deleted group
|
||||||
|
GroupDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -192,6 +194,8 @@ const (
|
|||||||
UserBlockedMessage string = "User blocked"
|
UserBlockedMessage string = "User blocked"
|
||||||
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||||
UserUnblockedMessage string = "User unblocked"
|
UserUnblockedMessage string = "User unblocked"
|
||||||
|
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
|
||||||
|
GroupDeletedMessage string = "Group deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Activity that triggered an Event
|
// Activity that triggered an Event
|
||||||
@@ -294,6 +298,8 @@ func (a Activity) Message() string {
|
|||||||
return UserBlockedMessage
|
return UserBlockedMessage
|
||||||
case UserUnblocked:
|
case UserUnblocked:
|
||||||
return UserUnblockedMessage
|
return UserUnblockedMessage
|
||||||
|
case GroupDeleted:
|
||||||
|
return GroupDeletedMessage
|
||||||
default:
|
default:
|
||||||
return "UNKNOWN_ACTIVITY"
|
return "UNKNOWN_ACTIVITY"
|
||||||
}
|
}
|
||||||
@@ -342,6 +348,8 @@ func (a Activity) StringCode() string {
|
|||||||
return "group.add"
|
return "group.add"
|
||||||
case GroupUpdated:
|
case GroupUpdated:
|
||||||
return "group.update"
|
return "group.update"
|
||||||
|
case GroupDeleted:
|
||||||
|
return "group.delete"
|
||||||
case GroupRemovedFromPeer:
|
case GroupRemovedFromPeer:
|
||||||
return "peer.group.delete"
|
return "peer.group.delete"
|
||||||
case GroupAddedToPeer:
|
case GroupAddedToPeer:
|
||||||
|
|||||||
@@ -2,13 +2,15 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultTTL = 300
|
const defaultTTL = 300
|
||||||
@@ -199,8 +201,10 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
|||||||
for _, gID := range nsGroup.Groups {
|
for _, gID := range nsGroup.Groups {
|
||||||
_, found := groupList[gID]
|
_, found := groupList[gID]
|
||||||
if found {
|
if found {
|
||||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||||
break
|
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,6 +212,16 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
|||||||
return peerNSGroups
|
return peerNSGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||||
|
func peerIsNameserver(peer *Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
@@ -17,6 +19,7 @@ const (
|
|||||||
dnsAccountID = "testingAcc"
|
dnsAccountID = "testingAcc"
|
||||||
dnsAdminUserID = "testingAdminUser"
|
dnsAdminUserID = "testingAdminUser"
|
||||||
dnsRegularUserID = "testingRegularUser"
|
dnsRegularUserID = "testingRegularUser"
|
||||||
|
dnsNSGroup1 = "ns1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetDNSSettings(t *testing.T) {
|
func TestGetDNSSettings(t *testing.T) {
|
||||||
@@ -163,6 +166,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||||
|
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||||
|
|
||||||
dnsSettings := account.DNSSettings.Copy()
|
dnsSettings := account.DNSSettings.Copy()
|
||||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||||
@@ -174,11 +178,11 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||||
|
|
||||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||||
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||||
@@ -246,7 +250,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = am.AddPeer("", dnsAdminUserID, peer1)
|
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -284,6 +288,24 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
|||||||
account.Groups[newGroup1.ID] = newGroup1
|
account.Groups[newGroup1.ID] = newGroup1
|
||||||
account.Groups[newGroup2.ID] = newGroup2
|
account.Groups[newGroup2.ID] = newGroup2
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||||
|
ID: dnsNSGroup1,
|
||||||
|
Name: "ns-group-1",
|
||||||
|
NameServers: []dns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||||
|
NSType: dns.UDPNameServerType,
|
||||||
|
Port: dns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Primary: true,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{allGroup.ID},
|
||||||
|
}
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
|
|||||||
addPeerLabelsToAccount(account, existingLabels)
|
addPeerLabelsToAccount(account, existingLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: delete this block after migration
|
||||||
|
// Set API as issuer for groups which has not this field
|
||||||
|
for _, group := range account.Groups {
|
||||||
|
if group.Issued == "" {
|
||||||
|
group.Issued = GroupIssuedAPI
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||||
@@ -281,6 +289,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
if account.Id == "" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "account id should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
accountCopy := account.Copy()
|
accountCopy := account.Copy()
|
||||||
|
|
||||||
s.Accounts[accountCopy.Id] = accountCopy
|
s.Accounts[accountCopy.Id] = accountCopy
|
||||||
@@ -326,7 +338,7 @@ func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
|||||||
|
|
||||||
delete(s.HashedPAT2TokenID, hashedToken)
|
delete(s.HashedPAT2TokenID, hashedToken)
|
||||||
|
|
||||||
return s.persist(s.storeFile)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
||||||
@@ -336,7 +348,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|||||||
|
|
||||||
delete(s.TokenID2UserID, tokenID)
|
delete(s.TokenID2UserID, tokenID)
|
||||||
|
|
||||||
return s.persist(s.storeFile)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPrivateDomain returns account by private domain
|
// GetAccountByPrivateDomain returns account by private domain
|
||||||
|
|||||||
@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
|
|||||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: outdated, delete this
|
||||||
func TestRestorePolicies_Migration(t *testing.T) {
|
func TestRestorePolicies_Migration(t *testing.T) {
|
||||||
storeDir := t.TempDir()
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
|
|||||||
"failed to restore a FileStore file - missing Account Policies Sources")
|
"failed to restore a FileStore file - missing Account Policies Sources")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRestoreGroups_Migration(t *testing.T) {
|
||||||
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
|
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := NewFileStore(storeDir, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create default group
|
||||||
|
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||||
|
account.Groups = map[string]*Group{
|
||||||
|
"cfefqs706sqkneg59g3g": {
|
||||||
|
ID: "cfefqs706sqkneg59g3g",
|
||||||
|
Name: "All",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SaveAccount(account)
|
||||||
|
require.NoError(t, err, "failed to save account")
|
||||||
|
|
||||||
|
// restore account with default group with empty Issue field
|
||||||
|
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||||
|
|
||||||
|
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
|
||||||
|
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||||
storeDir := t.TempDir()
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,23 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type GroupLinkError struct {
|
||||||
|
Resource string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *GroupLinkError) Error() string {
|
||||||
|
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||||
|
}
|
||||||
|
|
||||||
// Group of the peers for ACL
|
// Group of the peers for ACL
|
||||||
type Group struct {
|
type Group struct {
|
||||||
// ID of the group
|
// ID of the group
|
||||||
@@ -14,6 +26,9 @@ type Group struct {
|
|||||||
// Name visible in the UI
|
// Name visible in the UI
|
||||||
Name string
|
Name string
|
||||||
|
|
||||||
|
// Issued of the group
|
||||||
|
Issued string
|
||||||
|
|
||||||
// Peers list of the group
|
// Peers list of the group
|
||||||
Peers []string
|
Peers []string
|
||||||
}
|
}
|
||||||
@@ -45,9 +60,10 @@ func (g *Group) EventMeta() map[string]any {
|
|||||||
|
|
||||||
func (g *Group) Copy() *Group {
|
func (g *Group) Copy() *Group {
|
||||||
return &Group{
|
return &Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Peers: g.Peers[:],
|
Issued: g.Issued,
|
||||||
|
Peers: g.Peers[:],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,15 +215,80 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup object of the peers
|
// DeleteGroup object of the peers
|
||||||
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountId)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
g, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check route links
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
for _, g := range r.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"route", r.NetID}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DNS links
|
||||||
|
for _, dns := range account.NameServerGroups {
|
||||||
|
for _, g := range dns.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"name server groups", dns.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check ACL links
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
for _, src := range rule.Sources {
|
||||||
|
if src == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dst := range rule.Destinations {
|
||||||
|
if dst == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check setup key links
|
||||||
|
for _, setupKey := range account.SetupKeys {
|
||||||
|
for _, grp := range setupKey.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"setup key", setupKey.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check user links
|
||||||
|
for _, user := range account.Users {
|
||||||
|
for _, grp := range user.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"user", user.Id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DisabledManagementGroups
|
||||||
|
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||||
|
if disabledMgmGrp == groupID {
|
||||||
|
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(account.Groups, groupID)
|
delete(account.Groups, groupID)
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
@@ -215,6 +296,8 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
return am.updateAccountPeers(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
164
management/server/group_test.go
Normal file
164
management/server/group_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
groupAdminUserID = "testingAdminUser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||||
|
am, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create account manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := initTestGroupAccount(am)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to init testing account")
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
groupID string
|
||||||
|
expectedReason string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"route",
|
||||||
|
"grp-for-route",
|
||||||
|
"route",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name server groups",
|
||||||
|
"grp-for-name-server-grp",
|
||||||
|
"name server groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"policy",
|
||||||
|
"grp-for-policies",
|
||||||
|
"policy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup keys",
|
||||||
|
"grp-for-keys",
|
||||||
|
"setup key",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"users",
|
||||||
|
"grp-for-users",
|
||||||
|
"user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
err = am.DeleteGroup(account.Id, "", testCase.groupID)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("delete %s group successfully", testCase.groupID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gErr, ok := err.(*GroupLinkError)
|
||||||
|
if !ok {
|
||||||
|
t.Error("invalid error type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gErr.Resource != testCase.expectedReason {
|
||||||
|
t.Errorf("invalid error case: %s, expected: %s", gErr.Resource, testCase.expectedReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||||
|
accountID := "testingAcc"
|
||||||
|
domain := "example.com"
|
||||||
|
|
||||||
|
groupForRoute := &Group{
|
||||||
|
"grp-for-route",
|
||||||
|
"Group for route",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForNameServerGroups := &Group{
|
||||||
|
"grp-for-name-server-grp",
|
||||||
|
"Group for name server groups",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForPolicies := &Group{
|
||||||
|
"grp-for-policies",
|
||||||
|
"Group for policies",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForSetupKeys := &Group{
|
||||||
|
"grp-for-keys",
|
||||||
|
"Group for setup keys",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForUsers := &Group{
|
||||||
|
"grp-for-users",
|
||||||
|
"Group for users",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
routeResource := &route.Route{
|
||||||
|
ID: "example route",
|
||||||
|
Groups: []string{groupForRoute.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
nameServerGroup := &nbdns.NameServerGroup{
|
||||||
|
ID: "example name server group",
|
||||||
|
Groups: []string{groupForNameServerGroups.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := &Policy{
|
||||||
|
ID: "example policy",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "example policy rule",
|
||||||
|
Destinations: []string{groupForPolicies.ID},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey := &SetupKey{
|
||||||
|
Id: "example setup key",
|
||||||
|
AutoGroups: []string{groupForSetupKeys.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &User{
|
||||||
|
Id: "example user",
|
||||||
|
AutoGroups: []string{groupForUsers.ID},
|
||||||
|
}
|
||||||
|
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||||
|
account.Routes[routeResource.ID] = routeResource
|
||||||
|
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||||
|
account.Policies = append(account.Policies, policy)
|
||||||
|
account.SetupKeys[setupKey.Id] = setupKey
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
|
||||||
|
err := am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
|
||||||
|
|
||||||
|
return am.Store.GetAccount(account.Id)
|
||||||
|
}
|
||||||
@@ -72,10 +72,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, &server.Settings{
|
settings := &server.Settings{
|
||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if req.Settings.JwtGroupsEnabled != nil {
|
||||||
|
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.JwtGroupsClaimName != nil {
|
||||||
|
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
@@ -93,6 +102,8 @@ func toAccountResponse(account *server.Account) *api.Account {
|
|||||||
Settings: api.AccountSettings{
|
Settings: api.AccountSettings{
|
||||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
||||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
||||||
|
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
||||||
|
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
|
|
||||||
|
sr := func(v string) *string { return &v }
|
||||||
|
br := func(v bool) *bool { return &v }
|
||||||
|
|
||||||
handler := initAccountsTestData(&server.Account{
|
handler := initAccountsTestData(&server.Account{
|
||||||
Id: accountID,
|
Id: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@@ -91,6 +94,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: int(time.Hour.Seconds()),
|
PeerLoginExpiration: int(time.Hour.Seconds()),
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
|
JwtGroupsClaimName: sr(""),
|
||||||
|
JwtGroupsEnabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: true,
|
expectedArray: true,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -105,6 +110,24 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: 15552000,
|
PeerLoginExpiration: 15552000,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
|
JwtGroupsClaimName: sr(""),
|
||||||
|
JwtGroupsEnabled: br(false),
|
||||||
|
},
|
||||||
|
expectedArray: false,
|
||||||
|
expectedID: accountID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PutAccount OK wiht JWT",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/accounts/" + accountID,
|
||||||
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedSettings: api.AccountSettings{
|
||||||
|
PeerLoginExpiration: 15552000,
|
||||||
|
PeerLoginExpirationEnabled: false,
|
||||||
|
JwtGroupsClaimName: sr("roles"),
|
||||||
|
JwtGroupsEnabled: br(true),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ components:
|
|||||||
description: Period of time after which peer login expires (seconds).
|
description: Period of time after which peer login expires (seconds).
|
||||||
type: integer
|
type: integer
|
||||||
example: 43200
|
example: 43200
|
||||||
|
jwt_groups_enabled:
|
||||||
|
description: Allows extract groups from JWT claim and add it to account groups.
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
jwt_groups_claim_name:
|
||||||
|
description: Name of the claim from which we extract groups names to add it to account groups.
|
||||||
|
type: string
|
||||||
|
example: "roles"
|
||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
@@ -462,6 +470,10 @@ components:
|
|||||||
description: Count of peers associated to the group
|
description: Count of peers associated to the group
|
||||||
type: integer
|
type: integer
|
||||||
example: 2
|
example: 2
|
||||||
|
issued:
|
||||||
|
description: How group was issued by API or from JWT token
|
||||||
|
type: string
|
||||||
|
example: api
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- name
|
- name
|
||||||
@@ -1262,6 +1274,33 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/users/{userId}/invite:
|
||||||
|
post:
|
||||||
|
summary: Resend user invitation
|
||||||
|
description: Resend user invitation
|
||||||
|
tags: [ Users ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: userId
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The unique identifier of a user
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Invite status code
|
||||||
|
content: {}
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/peers:
|
/api/peers:
|
||||||
get:
|
get:
|
||||||
summary: List all Peers
|
summary: List all Peers
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ type AccountRequest struct {
|
|||||||
|
|
||||||
// AccountSettings defines model for AccountSettings.
|
// AccountSettings defines model for AccountSettings.
|
||||||
type AccountSettings struct {
|
type AccountSettings struct {
|
||||||
|
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
|
||||||
|
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
|
||||||
|
|
||||||
|
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||||
|
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||||
|
|
||||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||||
|
|
||||||
@@ -174,6 +180,9 @@ type Group struct {
|
|||||||
// Id Group ID
|
// Id Group ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// Issued How group was issued by API or from JWT token
|
||||||
|
Issued *string `json:"issued,omitempty"`
|
||||||
|
|
||||||
// Name Group Name identifier
|
// Name Group Name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
@@ -189,6 +198,9 @@ type GroupMinimum struct {
|
|||||||
// Id Group ID
|
// Id Group ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// Issued How group was issued by API or from JWT token
|
||||||
|
Issued *string `json:"issued,omitempty"`
|
||||||
|
|
||||||
// Name Group Name identifier
|
// Name Group Name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok = account.Groups[groupID]
|
eg, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||||
return
|
return
|
||||||
@@ -107,9 +107,10 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
peers = *req.Peers
|
peers = *req.Peers
|
||||||
}
|
}
|
||||||
group := server.Group{
|
group := server.Group{
|
||||||
ID: groupID,
|
ID: groupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
|
Issued: eg.Issued,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
||||||
@@ -149,9 +150,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
peers = *req.Peers
|
peers = *req.Peers
|
||||||
}
|
}
|
||||||
group := server.Group{
|
group := server.Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
|
Issued: server.GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
||||||
@@ -166,7 +168,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteGroup handles group deletion request
|
// DeleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -190,8 +192,13 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteGroup(aID, groupID)
|
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_, ok := err.(*server.GroupLinkError)
|
||||||
|
if ok {
|
||||||
|
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -237,6 +244,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
|
|||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
PeersCount: len(group.Peers),
|
PeersCount: len(group.Peers),
|
||||||
|
Issued: &group.Issued,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
|
|||||||
@@ -11,17 +11,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var TestPeers = map[string]*server.Peer{
|
var TestPeers = map[string]*server.Peer{
|
||||||
@@ -42,9 +40,17 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
if groupID != "idofthegroup" {
|
if groupID != "idofthegroup" {
|
||||||
return nil, status.Errorf(status.NotFound, "not found")
|
return nil, status.Errorf(status.NotFound, "not found")
|
||||||
}
|
}
|
||||||
|
if groupID == "id-jwt-group" {
|
||||||
|
return &server.Group{
|
||||||
|
ID: "id-jwt-group",
|
||||||
|
Name: "Default Group",
|
||||||
|
Issued: server.GroupIssuedJWT,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
return &server.Group{
|
return &server.Group{
|
||||||
ID: "idofthegroup",
|
ID: "idofthegroup",
|
||||||
Name: "Group",
|
Name: "Group",
|
||||||
|
Issued: server.GroupIssuedAPI,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
||||||
@@ -80,11 +86,24 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
user.Id: user,
|
user.Id: user,
|
||||||
},
|
},
|
||||||
Groups: map[string]*server.Group{
|
Groups: map[string]*server.Group{
|
||||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT},
|
||||||
"id-all": {ID: "id-all", Name: "All"},
|
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI},
|
||||||
|
"id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI},
|
||||||
},
|
},
|
||||||
}, user, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
|
DeleteGroupFunc: func(accountID, userId, groupID string) error {
|
||||||
|
if groupID == "linked-grp" {
|
||||||
|
return &server.GroupLinkError{
|
||||||
|
Resource: "something",
|
||||||
|
Name: "linked-grp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupID == "invalid-grp" {
|
||||||
|
return fmt.Errorf("internal error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -169,6 +188,8 @@ func TestGetGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteGroup(t *testing.T) {
|
func TestWriteGroup(t *testing.T) {
|
||||||
|
groupIssuedAPI := "api"
|
||||||
|
groupIssuedJWT := "jwt"
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@@ -187,8 +208,9 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedGroup: &api.Group{
|
expectedGroup: &api.Group{
|
||||||
Id: "id-was-set",
|
Id: "id-was-set",
|
||||||
Name: "Default POSTed Group",
|
Name: "Default POSTed Group",
|
||||||
|
Issued: &groupIssuedAPI,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -208,8 +230,9 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
[]byte(`{"Name":"Default POSTed Group"}`)),
|
[]byte(`{"Name":"Default POSTed Group"}`)),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedGroup: &api.Group{
|
expectedGroup: &api.Group{
|
||||||
Id: "id-existed",
|
Id: "id-existed",
|
||||||
Name: "Default POSTed Group",
|
Name: "Default POSTed Group",
|
||||||
|
Issued: &groupIssuedAPI,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -230,6 +253,19 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Write Group PUT not not change Issue",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/groups/id-jwt-group",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(`{"Name":"changed","Issued":"api"}`)),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedGroup: &api.Group{
|
||||||
|
Id: "id-jwt-group",
|
||||||
|
Name: "changed",
|
||||||
|
Issued: &groupIssuedJWT,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
@@ -271,3 +307,79 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteGroup(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Try to delete linked group",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/linked-grp",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Try to cause internal error",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/invalid-grp",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Try to cause internal error",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/invalid-grp",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Delete group",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/any-grp",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
adminUser := server.NewAdminUser("test_user")
|
||||||
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||||
|
status, tc.expectedStatus, string(content))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedBody {
|
||||||
|
got := &util.ErrorResponse{}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, got.Code, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
|||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||||
|
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
case "bearer":
|
case "bearer":
|
||||||
err := m.CheckJWTFromRequest(w, r)
|
err := m.CheckJWTFromRequest(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Error when validating JWT claims: %s", err.Error())
|
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -208,6 +208,37 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(w, users)
|
util.WriteJSONObject(w, users)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUser resend invitations to users who haven't activated their accounts,
|
||||||
|
// prior to the expiration period.
|
||||||
|
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
targetUserID := vars["userId"]
|
||||||
|
if len(targetUserID) == 0 {
|
||||||
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(w, emptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||||
autoGroups := user.AutoGroups
|
autoGroups := user.AutoGroups
|
||||||
if autoGroups == nil {
|
if autoGroups == nil {
|
||||||
|
|||||||
@@ -98,6 +98,17 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return info, nil
|
return info, nil
|
||||||
},
|
},
|
||||||
|
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
|
if initiatorUserID != existingUserID {
|
||||||
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetUserID == notFoundUserID {
|
||||||
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -340,6 +351,51 @@ func TestCreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInviteUser(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestVars map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Invite User with Existing User",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/users/" + existingUserID + "/invite",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
requestVars: map[string]string{"userId": existingUserID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invite User with missing user_id",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/users/" + notFoundUserID + "/invite",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
requestVars: map[string]string{"userId": notFoundUserID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userHandler := initUsersTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
userHandler.InviteUser(rr, req)
|
||||||
|
|
||||||
|
res := rr.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := rr.Code; status != tc.expectedStatus {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDeleteUser(t *testing.T) {
|
func TestDeleteUser(t *testing.T) {
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -58,14 +63,9 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
|
|||||||
|
|
||||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||||
type errorResponse struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Code int `json:"code"`
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(httpStatus)
|
||||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||||
err := json.NewEncoder(w).Encode(&errorResponse{
|
err := json.NewEncoder(w).Encode(&ErrorResponse{
|
||||||
Message: errMsg,
|
Message: errMsg,
|
||||||
Code: httpStatus,
|
Code: httpStatus,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -98,6 +98,11 @@ type userExportJobStatusResponse struct {
|
|||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userVerificationJobRequest is a user verification request struct
|
||||||
|
type userVerificationJobRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
// auth0Profile represents an Auth0 user profile response
|
// auth0Profile represents an Auth0 user profile response
|
||||||
type auth0Profile struct {
|
type auth0Profile struct {
|
||||||
AccountID string `json:"wt_account_id"`
|
AccountID string `json:"wt_account_id"`
|
||||||
@@ -689,6 +694,48 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
|
|||||||
return &createResp, nil
|
return &createResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||||
|
userVerificationReq := userVerificationJobRequest{
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := am.helper.Marshal(userVerificationReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Couldn't get job response %v", err)
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while closing invite user response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||||
// If the status is "completed", then return the downloadLink
|
// If the status is "completed", then return the downloadLink
|
||||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||||
|
|||||||
@@ -440,6 +440,12 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
|||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -448,6 +448,12 @@ func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *AzureManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", extensionFields)
|
q.Add("$select", extensionFields)
|
||||||
|
|||||||
@@ -247,6 +247,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
|
|||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||||
// If that fails, it falls back to using the default Google credentials path.
|
// If that fails, it falls back to using the default Google credentials path.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type Manager interface {
|
|||||||
GetAllAccounts() (map[string][]*UserData, error)
|
GetAllAccounts() (map[string][]*UserData, error)
|
||||||
CreateUser(email string, name string, accountID string) (*UserData, error)
|
CreateUser(email string, name string, accountID string) (*UserData, error)
|
||||||
GetUserByEmail(email string) ([]*UserData, error)
|
GetUserByEmail(email string) ([]*UserData, error)
|
||||||
|
InviteUserByID(userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConfig defines common client configuration for all IdP manager
|
// ClientConfig defines common client configuration for all IdP manager
|
||||||
|
|||||||
@@ -461,6 +461,12 @@ func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppM
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||||
attrs := keycloakUserAttributes{}
|
attrs := keycloakUserAttributes{}
|
||||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||||
|
|||||||
@@ -270,21 +270,32 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
|||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
var pendingInvite bool
|
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||||
if appMetadata.WTPendingInvite != nil {
|
if err != nil {
|
||||||
pendingInvite = *appMetadata.WTPendingInvite
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, resp, err := om.client.User.UpdateUser(context.Background(), userID,
|
if resp.StatusCode != http.StatusOK {
|
||||||
okta.User{
|
if om.appMetrics != nil {
|
||||||
Profile: &okta.UserProfile{
|
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
wtAccountID: appMetadata.WTAccountID,
|
}
|
||||||
wtPendingInvite: pendingInvite,
|
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
|
||||||
},
|
}
|
||||||
},
|
|
||||||
nil,
|
profile := *user.Profile
|
||||||
)
|
|
||||||
|
if appMetadata.WTPendingInvite != nil {
|
||||||
|
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
|
||||||
|
}
|
||||||
|
|
||||||
|
if appMetadata.WTAccountID != "" {
|
||||||
|
profile[wtAccountID] = appMetadata.WTAccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Profile = &profile
|
||||||
|
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println(err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,10 +313,18 @@ func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetad
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (om *OktaManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
||||||
// wt_account_id and wt_pending_invite.
|
// wt_account_id and wt_pending_invite.
|
||||||
func updateUserProfileSchema(client *okta.Client) error {
|
func updateUserProfileSchema(client *okta.Client) error {
|
||||||
required := true
|
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
|
||||||
|
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
|
||||||
|
|
||||||
_, resp, err := client.UserSchema.UpdateUserProfile(
|
_, resp, err := client.UserSchema.UpdateUserProfile(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
"default",
|
"default",
|
||||||
@@ -316,18 +335,20 @@ func updateUserProfileSchema(client *okta.Client) error {
|
|||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]*okta.UserSchemaAttribute{
|
Properties: map[string]*okta.UserSchemaAttribute{
|
||||||
wtAccountID: {
|
wtAccountID: {
|
||||||
MaxLength: 100,
|
MaxLength: 100,
|
||||||
MinLength: 1,
|
MinLength: 1,
|
||||||
Required: &required,
|
Required: new(bool),
|
||||||
Scope: "NONE",
|
Scope: "NONE",
|
||||||
Title: "Wt Account Id",
|
Title: "Wt Account Id",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
|
Permissions: userPermissions,
|
||||||
},
|
},
|
||||||
wtPendingInvite: {
|
wtPendingInvite: {
|
||||||
Required: new(bool),
|
Required: new(bool),
|
||||||
Scope: "NONE",
|
Scope: "NONE",
|
||||||
Title: "Wt Pending Invite",
|
Title: "Wt Pending Invite",
|
||||||
Type: "boolean",
|
Type: "boolean",
|
||||||
|
Permissions: userPermissions,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -441,6 +441,12 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (zm *ZitadelManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// getUserMetadata requests user metadata from zitadel via ID.
|
// getUserMetadata requests user metadata from zitadel via ID.
|
||||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
package jwtclaims
|
package jwtclaims
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
// AuthorizationClaims stores authorization information from JWTs
|
// AuthorizationClaims stores authorization information from JWTs
|
||||||
type AuthorizationClaims struct {
|
type AuthorizationClaims struct {
|
||||||
UserId string
|
UserId string
|
||||||
AccountId string
|
AccountId string
|
||||||
Domain string
|
Domain string
|
||||||
DomainCategory string
|
DomainCategory string
|
||||||
|
|
||||||
|
Raw jwt.MapClaims
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
|||||||
// FromToken extracts claims from the token (after auth)
|
// FromToken extracts claims from the token (after auth)
|
||||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
jwtClaims := AuthorizationClaims{}
|
jwtClaims := AuthorizationClaims{
|
||||||
|
Raw: claims,
|
||||||
|
}
|
||||||
userID, ok := claims[c.userIDClaim].(string)
|
userID, ok := claims[c.userIDClaim].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return jwtClaims
|
return jwtClaims
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
Domain: "test.com",
|
Domain: "test.com",
|
||||||
AccountId: "testAcc",
|
AccountId: "testAcc",
|
||||||
DomainCategory: "public",
|
DomainCategory: "public",
|
||||||
|
Raw: jwt.MapClaims{
|
||||||
|
"https://login/wt_account_domain": "test.com",
|
||||||
|
"https://login/wt_account_domain_category": "public",
|
||||||
|
"https://login/wt_account_id": "testAcc",
|
||||||
|
"sub": "test",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectedMSG: "extracted claims should match input claims",
|
expectedMSG: "extracted claims should match input claims",
|
||||||
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
inputAuthorizationClaims: AuthorizationClaims{
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
UserId: "test",
|
UserId: "test",
|
||||||
AccountId: "testAcc",
|
AccountId: "testAcc",
|
||||||
|
Raw: jwt.MapClaims{
|
||||||
|
"https://login/wt_account_id": "testAcc",
|
||||||
|
"sub": "test",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectedMSG: "extracted claims should match input claims",
|
expectedMSG: "extracted claims should match input claims",
|
||||||
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
inputAuthorizationClaims: AuthorizationClaims{
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
UserId: "test",
|
UserId: "test",
|
||||||
Domain: "test.com",
|
Domain: "test.com",
|
||||||
|
Raw: jwt.MapClaims{
|
||||||
|
"https://login/wt_account_domain": "test.com",
|
||||||
|
"sub": "test",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectedMSG: "extracted claims should match input claims",
|
expectedMSG: "extracted claims should match input claims",
|
||||||
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
UserId: "test",
|
UserId: "test",
|
||||||
Domain: "test.com",
|
Domain: "test.com",
|
||||||
AccountId: "testAcc",
|
AccountId: "testAcc",
|
||||||
|
Raw: jwt.MapClaims{
|
||||||
|
"https://login/wt_account_domain": "test.com",
|
||||||
|
"https://login/wt_account_id": "testAcc",
|
||||||
|
"sub": "test",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectedMSG: "extracted claims should match input claims",
|
expectedMSG: "extracted claims should match input claims",
|
||||||
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
inputAudiance: "https://login/",
|
inputAudiance: "https://login/",
|
||||||
inputAuthorizationClaims: AuthorizationClaims{
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
UserId: "test",
|
UserId: "test",
|
||||||
|
Raw: jwt.MapClaims{
|
||||||
|
"sub": "test",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectedMSG: "extracted claims should match input claims",
|
expectedMSG: "extracted claims should match input claims",
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
|||||||
|
|
||||||
// Check if there was an error in parsing...
|
// Check if there was an error in parsing...
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("error parsing token: %v", err)
|
log.Errorf("error parsing token: %v", err)
|
||||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
return nil, fmt.Errorf("Error parsing token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ type ConnManager interface {
|
|||||||
type Worker struct {
|
type Worker struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
id string
|
id string
|
||||||
|
idpManager string
|
||||||
dataSource DataSource
|
dataSource DataSource
|
||||||
connManager ConnManager
|
connManager ConnManager
|
||||||
startupTime time.Time
|
startupTime time.Time
|
||||||
@@ -66,11 +67,12 @@ type Worker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewWorker returns a metrics worker
|
// NewWorker returns a metrics worker
|
||||||
func NewWorker(ctx context.Context, id string, dataSource DataSource, connManager ConnManager) *Worker {
|
func NewWorker(ctx context.Context, id string, dataSource DataSource, connManager ConnManager, idpManager string) *Worker {
|
||||||
currentTime := time.Now()
|
currentTime := time.Now()
|
||||||
return &Worker{
|
return &Worker{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
id: id,
|
id: id,
|
||||||
|
idpManager: idpManager,
|
||||||
dataSource: dataSource,
|
dataSource: dataSource,
|
||||||
connManager: connManager,
|
connManager: connManager,
|
||||||
startupTime: currentTime,
|
startupTime: currentTime,
|
||||||
@@ -277,6 +279,7 @@ func (w *Worker) generateProperties() properties {
|
|||||||
metricsProperties["min_active_peer_version"] = minActivePeerVersion
|
metricsProperties["min_active_peer_version"] = minActivePeerVersion
|
||||||
metricsProperties["max_active_peer_version"] = maxActivePeerVersion
|
metricsProperties["max_active_peer_version"] = maxActivePeerVersion
|
||||||
metricsProperties["ui_clients"] = uiClient
|
metricsProperties["ui_clients"] = uiClient
|
||||||
|
metricsProperties["idp_manager"] = w.idpManager
|
||||||
|
|
||||||
for protocol, count := range rulesProtocol {
|
for protocol, count := range rulesProtocol {
|
||||||
metricsProperties["rules_protocol_"+protocol] = count
|
metricsProperties["rules_protocol_"+protocol] = count
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type MockAccountManager struct {
|
|||||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||||
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
||||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
||||||
DeleteGroupFunc func(accountID, groupID string) error
|
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||||
@@ -81,6 +81,7 @@ type MockAccountManager struct {
|
|||||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||||
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
|
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
|
||||||
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
||||||
|
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||||
@@ -274,9 +275,9 @@ func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, oper
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
|
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
if am.DeleteGroupFunc != nil {
|
if am.DeleteGroupFunc != nil {
|
||||||
return am.DeleteGroupFunc(accountID, groupID)
|
return am.DeleteGroupFunc(accountId, userId, groupID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
||||||
}
|
}
|
||||||
@@ -500,6 +501,13 @@ func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID strin
|
|||||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
|
if am.InviteUserFunc != nil {
|
||||||
|
return am.InviteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
if am.GetNameServerGroupFunc != nil {
|
if am.GetNameServerGroupFunc != nil {
|
||||||
|
|||||||
@@ -78,11 +78,14 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
if key.Type == SetupKeyReusable {
|
return
|
||||||
setupKey = key
|
}
|
||||||
}
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||||
@@ -328,7 +331,15 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKey := getSetupKey(account, SetupKeyReusable)
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -394,7 +405,15 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// two peers one added by a regular user and one with a setup key
|
// two peers one added by a regular user and one with a setup key
|
||||||
setupKey := getSetupKey(account, SetupKeyReusable)
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
|
}
|
||||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -470,13 +489,3 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.NotNil(t, peer)
|
assert.NotNil(t, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSetupKey(account *Account, keyType SetupKeyType) *SetupKey {
|
|
||||||
var setupKey *SetupKey
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
if key.Type == keyType {
|
|
||||||
setupKey = key
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return setupKey
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@@ -240,7 +242,15 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
|||||||
peersExists := make(map[string]struct{})
|
peersExists := make(map[string]struct{})
|
||||||
rules := make([]*FirewallRule, 0)
|
rules := make([]*FirewallRule, 0)
|
||||||
peers := make([]*Peer, 0)
|
peers := make([]*Peer, 0)
|
||||||
|
|
||||||
|
all, err := a.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get group all: %v", err)
|
||||||
|
all = &Group{}
|
||||||
|
}
|
||||||
|
|
||||||
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
||||||
|
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||||
for _, peer := range groupPeers {
|
for _, peer := range groupPeers {
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
continue
|
continue
|
||||||
@@ -250,29 +260,33 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
|||||||
peersExists[peer.ID] = struct{}{}
|
peersExists[peer.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
fwRule := FirewallRule{
|
fr := FirewallRule{
|
||||||
PeerIP: peer.IP.String(),
|
PeerIP: peer.IP.String(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: string(rule.Action),
|
Action: string(rule.Action),
|
||||||
Protocol: string(rule.Protocol),
|
Protocol: string(rule.Protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := fmt.Sprintf("%s%d", peer.ID+peer.IP.String(), direction)
|
if isAll {
|
||||||
ruleID += string(rule.Protocol) + string(rule.Action) + strings.Join(rule.Ports, ",")
|
fr.PeerIP = "0.0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
|
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||||
if _, ok := rulesExists[ruleID]; ok {
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rulesExists[ruleID] = struct{}{}
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
|
||||||
if len(rule.Ports) == 0 {
|
if len(rule.Ports) == 0 {
|
||||||
rules = append(rules, &fwRule)
|
rules = append(rules, &fr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, port := range rule.Ports {
|
for _, port := range rule.Ports {
|
||||||
addRule := fwRule
|
pr := fr // clone rule and add set new port
|
||||||
addRule.Port = port
|
pr.Port = port
|
||||||
rules = append(rules, &addRule)
|
rules = append(rules, &pr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, func() ([]*Peer, []*FirewallRule) {
|
}, func() ([]*Peer, []*FirewallRule) {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user