Compare commits

...

67 Commits

Author SHA1 Message Date
Givi Khojanashvili
6dee89379b Feat optimize acl performance iptables (#1025)
* use ipset for iptables

* Update unit-tests for iptables

* Remove debug code

* Update dependencies

* Create separate sets for dPort and sPort rules

* Fix iptables tests

* Fix 0.0.0.0 processing in iptables with ipset
2023-07-24 13:00:23 +02:00
Maycon Santos
76db4f801a Record idp manager type (#1027)
This allows to define priority on support different managers
2023-07-22 19:30:59 +02:00
Zoltan Papp
6c2ed4b4f2 Add default forward rule (#1021)
* Add default forward rule

* Fix

* Add multiple forward rules

* Fix delete rule error handling
2023-07-22 18:39:23 +02:00
Maycon Santos
2541c78dd0 Use error level for JWT parsing error logs (#1026) 2023-07-22 17:56:27 +02:00
Yury Gargay
97b6e79809 Fix DefaultAccountManager GetGroupsFromTheToken false positive tests (#1019)
This fixes the test logic creates copy of account with empty id and
re-pointing the indices to it.

Also, adds additional check for empty ID in SaveAccount method of FileStore.
2023-07-22 15:54:08 +04:00
Givi Khojanashvili
6ad3847615 Fix nfset not binds to the rule (#1024) 2023-07-21 17:45:58 +02:00
Bethuel Mmbaga
a4d830ef83 Fix Okta IDP device authorization (#1023)
* hide okta netbird attributes fields

* fix: update full user profile
2023-07-21 09:34:49 +02:00
pascal-fischer
9e540cd5b4 Merge pull request #1016 from surik/filestore-index-deletion-optimisation
Do not persist filestore when deleting indices
2023-07-20 18:07:33 +02:00
Zoltan Papp
3027d8f27e Sync the iptables/nftables usage with acl logic (#1017) 2023-07-19 19:10:27 +02:00
Givi Khojanashvili
e69ec6ab6a Optimize ACL performance (#994)
* Optimize rules with All groups

* Use IP sets in ACLs (nftables implementation)

* Fix squash rule when we receive optimized rules list from management
2023-07-18 13:12:50 +04:00
Yury Gargay
7ddde41c92 Do not persist filestore when deleting indices
As both TokenID2UserID and HashedPAT2TokenID are in-memory indices and
not stored in the file.
2023-07-17 11:52:45 +02:00
Zoltan Papp
7ebe58f20a Feature/permanent dns (#967)
* Add DNS list argument for mobile client

* Write testable code

Many places are checked the wgInterface != nil condition.
It is doing it just because to avoid the real wgInterface creation for tests.
Instead of this involve a wgInterface interface what is moc-able.

* Refactor the DNS server internal code structure

With the fake resolver has been involved several
if-else statement and generated some unused
variables to distinguish the listener and fake
resolver solutions at running time. With this
commit the fake resolver and listener based
solution has been moved into two separated
structure. Name of this layer is the 'service'.
With this modification the unit test looks
simpler and open the option to add new logic for
the permanent DNS service usage for mobile
systems.



* Remove is running check in test

We can not ensure the state well so remove this
check. The test will fail if the server is not
running well.
2023-07-14 21:56:22 +02:00
Zoltan Papp
9c2c0e7934 Check links of groups before delete it (#1010)
* Check links of groups before delete it

* Add delete group handler test

* Rename dns error msg

* Add delete group test

* Remove rule check

The policy cover this scenario

* Fix test

* Check disabled management grps

* Change error message

* Add new activity for group delete event
2023-07-14 20:45:40 +02:00
pascal-fischer
c6af1037d9 FIx error on ip6tables not available (#999)
* adding check operation to confirm if ip*tables is available

* linter

* linter
2023-07-14 20:44:35 +02:00
Bethuel Mmbaga
5cb9a126f1 Fix pre-shared key not persistent (#1011)
* update pre-shared key if new key is not empty

* add unit test for empty pre-shared key
2023-07-13 10:49:15 +02:00
pascal-fischer
f40951cdf5 Merge pull request #991 from netbirdio/fix/improve_uspfilter_performance
Improve userspace filter performance
2023-07-12 18:02:29 +02:00
Pascal Fischer
6e264d9de7 fix rule order to solve DNS resolver issue 2023-07-11 19:58:21 +02:00
Bethuel Mmbaga
42db9773f4 Remove unused netbird UI dependencies (#1007)
* remove unused netbird-ui dependencies in deb package

* build netbird-ui with support for legacy appindicator

* add rpm package dendencies

* add binary build package

* remove dependencies
2023-07-10 21:09:16 +02:00
Bethuel Mmbaga
bb9f6f6d0a Add API Endpoint for Resending User Invitations in Auth0 (#989)
* add request handler for sending invite

* add InviteUser method to account manager interface

* add InviteUser mock

* add invite user endpoint to user handler

* add InviteUserByID to manager interface

* implement InviteUserByID in all idp managers

* resend user invitation

* add invite user handler tests

* refactor

* user userID for sending invitation

* fix typo

* refactor

* pass userId in url params
2023-07-03 12:20:19 +02:00
Yury Gargay
829ce6573e Fix broken links in README.md (#992) 2023-06-29 11:42:55 +02:00
Maycon Santos
a366d9e208 Prevent sending nameserver configuration when peer is set as NS (#962)
* Prevent sending nameserver configuration when peer is set as NS

* Add DNS filter tests
2023-06-28 17:29:02 +02:00
Pascal Fischer
e074c24487 add type for RuleSet 2023-06-28 14:09:23 +02:00
Pascal Fischer
54fe05f6d8 fix test 2023-06-28 10:35:29 +02:00
Pascal Fischer
33a155d9aa fix all rules check 2023-06-28 03:03:01 +02:00
Pascal Fischer
51878659f8 remove Rule index map 2023-06-28 02:50:12 +02:00
pascal-fischer
c000c05435 Merge pull request #983 from netbirdio/fix/ssh_connection_freeze
Fix ssh connection freeze
2023-06-27 18:10:30 +02:00
Pascal Fischer
b39ffef22c add missing all rule 2023-06-27 17:44:05 +02:00
Pascal Fischer
d96f882acb seems to work but delete fails 2023-06-27 17:26:15 +02:00
Misha Bragin
d409219b51 Don't create setup keys on new account (#972) 2023-06-27 17:17:24 +02:00
Givi Khojanashvili
8b619a8224 JWT Groups support (#966)
Get groups from the JWT tokens if the feature enabled for the account
2023-06-27 18:51:05 +04:00
Maycon Santos
ed075bc9b9 Refactor: Configurable supported scopes (#985)
* Refactor: Configurable supported scopes

Previously, supported scopes were hardcoded and limited to Auth0
and Keycloak. This update removes the default set of values,
providing flexibility. The value to be set for each Identity
Provider (IDP) is specified in their respective documentation.

* correct var

* correct var

* skip fetching scopes from openid-configuration
2023-06-25 13:59:45 +02:00
Pascal Fischer
8eb098d6fd add sleep and comment 2023-06-23 17:02:34 +02:00
Pascal Fischer
68a8687c80 fix linter 2023-06-23 16:45:07 +02:00
Pascal Fischer
f7d97b02fd fix error codes on cli 2023-06-23 16:27:10 +02:00
Pascal Fischer
2691e729cd fix ssh 2023-06-23 12:20:14 +02:00
Givi Khojanashvili
b524a9d49d Fix use wrpped device in windows (#981) 2023-06-23 10:01:22 +02:00
Givi Khojanashvili
774d8e955c Fix disabled DNS resolver fail (#978)
Fix fail of DNS when it disabled in the settings
2023-06-22 16:59:21 +04:00
Givi Khojanashvili
c20f98c8b6 ACL firewall manager fix/improvement (#970)
* ACL firewall manager fix/improvement

Fix issue with rule squashing, it contained issue when calculated
total amount of IPs in the Peer map (doesn't included offline peers).
That why squashing not worked.
Also this commit changes the rules apply behaviour. Instead policy:
1. Apply all rules from network map
2. Remove all previous applied rules
We do:
1. Apply only new rules
2. Remove outdated rules
Why first variant was implemented: because when you have drop policy
it is important in which order order you rules are and you need totally
clean previous state to apply the new. But in the release we didn't
include drop policy so we can do this improvement.

* Print log message about processed ACL rules
2023-06-20 20:33:41 +02:00
Zoltan Papp
20ae540fb1 Fix the stop procedure in DefaultDns (#971) 2023-06-20 20:33:26 +02:00
Bethuel
58cfa2bb17 Add Google Workspace IdP (#949)
Added integration with Google Workspace user directory API.
2023-06-20 19:15:36 +02:00
pascal-fischer
06005cc10e Merge pull request #968 from netbirdio/chore/extend_gitignore_for_multiple_configs
Extend gitignore to ignore multiple configs
2023-06-19 17:17:12 +02:00
Pascal Fischer
1a3e377304 extend gitignore to ignore multiple config files 2023-06-19 15:07:27 +02:00
Zoltan Papp
dd29f4c01e Reduce the peer status notifications (#956)
Reduce the peer status notifications

When receive new network map invoke multiple notifications for 
every single peers. It cause high cpu usage We handle the in a 
batch the peer notification in update network map.

- Remove the unnecessary UpdatePeerFQDN calls in addNewPeer
- Fix notification in RemovePeer function
- Involve FinishPeerListModifications logic
2023-06-19 11:20:34 +02:00
pascal-fischer
cb7ecd1cc4 Merge pull request #945 from netbirdio/feat/refactor_route_adding_in_client
Refactor check logic when adding routes
2023-06-19 10:16:22 +02:00
Pascal Fischer
b5d8142705 test windows 2023-06-12 16:22:53 +02:00
Pascal Fischer
f45eb1a1da test windows 2023-06-12 16:12:24 +02:00
Pascal Fischer
2567006412 test windows 2023-06-12 16:01:06 +02:00
Pascal Fischer
b92107efc8 test windows 2023-06-12 15:38:47 +02:00
Pascal Fischer
5d19811331 test windows 2023-06-12 15:26:28 +02:00
Pascal Fischer
697d41c94e test windows 2023-06-12 15:14:51 +02:00
Pascal Fischer
75d541f967 test windows 2023-06-12 14:56:30 +02:00
Pascal Fischer
7dfbb71f7a test windows 2023-06-12 12:49:21 +02:00
Pascal Fischer
a5d14c92ff test windows 2023-06-12 12:16:00 +02:00
Pascal Fischer
ce091ab42b test windows 2023-06-12 11:43:18 +02:00
Pascal Fischer
d2fad1cfd9 testing windows 2023-06-12 11:06:49 +02:00
Pascal Fischer
0b5594f145 testing windows 2023-06-09 19:17:26 +02:00
Pascal Fischer
9beaa91db9 testing windows 2023-06-09 19:15:39 +02:00
Pascal Fischer
c8b4c08139 split systemops for operating systems and add linux 2023-06-09 18:48:21 +02:00
Pascal Fischer
dad5501a44 split systemops for operating systems and add linux 2023-06-09 18:40:35 +02:00
Pascal Fischer
1ced2462c1 split systemops for operating systems and add linux 2023-06-09 18:36:49 +02:00
Pascal Fischer
64adaeb276 split systemops for operating systems and add linux 2023-06-09 18:30:36 +02:00
Pascal Fischer
6e26d03fb8 split systemops for operating systems and add linux 2023-06-09 18:27:09 +02:00
Pascal Fischer
493ddb4fe3 Revert "hacky all-operating-systems solution"
This reverts commit 75fac258e7.
2023-06-09 17:59:06 +02:00
Pascal Fischer
75fac258e7 hacky all-operating-systems solution 2023-06-09 17:40:10 +02:00
Pascal Fischer
bc8ee8fc3c add tests 2023-06-09 16:18:48 +02:00
Pascal Fischer
3724323f76 test still failing 2023-06-09 15:33:22 +02:00
Pascal Fischer
3ef33874b1 change checks before route adding to not only check for default gateway (test missing) 2023-06-09 12:35:57 +02:00
112 changed files with 4462 additions and 1026 deletions

View File

@@ -116,7 +116,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install rsrc - name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2 run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc - name: Generate windows rsrc

View File

@@ -53,6 +53,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
- name: check values - name: check values
working-directory: infrastructure_files working-directory: infrastructure_files

7
.gitignore vendored
View File

@@ -7,8 +7,15 @@ bin/
conf.json conf.json
http-cmds.sh http-cmds.sh
infrastructure_files/management.json infrastructure_files/management.json
infrastructure_files/management-*.json
infrastructure_files/docker-compose.yml infrastructure_files/docker-compose.yml
infrastructure_files/openid-configuration.json
infrastructure_files/turnserver.conf
management/management
client/client
client/client.exe
*.syso *.syso
client/.distfiles/ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode .vscode

View File

@@ -11,6 +11,8 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows
@@ -55,9 +57,6 @@ nfpms:
- src: client/ui/disconnected.png - src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird - netbird
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
@@ -75,9 +74,6 @@ nfpms:
- src: client/ui/disconnected.png - src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird - netbird
uploads: uploads:

View File

@@ -70,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
### Start using NetBird ### Start using NetBird
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/). - Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart). - See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting). - If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms. - Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
- Web UI [repository](https://github.com/netbirdio/dashboard). - Web UI [repository](https://github.com/netbirdio/dashboard).
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube. - 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
@@ -91,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/> <img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
</p> </p>
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details. See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Roadmap ### Roadmap
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2) - [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -35,6 +36,11 @@ type RouteListener interface {
routemanager.RouteListener routemanager.RouteListener
} }
// DnsReadyListener export internal dns ReadyListener for mobile
type DnsReadyListener interface {
dns.ReadyListener
}
func init() { func init() {
formatter.SetLogcatFormatter(log.StandardLogger()) formatter.SetLogcatFormatter(log.StandardLogger())
} }
@@ -49,6 +55,7 @@ type Client struct {
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener routeListener routemanager.RouteListener
onHostDnsFn func([]string)
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener) c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()
if err != nil {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
return nil
}
// SetConnectionListener set the network connection listener // SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener) c.recorder.SetConnectionListener(listener)

View File

@@ -0,0 +1,26 @@
package android
import "fmt"
// DNSList is a wrapper of []string
type DNSList struct {
items []string
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
}
// Get return an element of the collection
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
}
// Size return with the size of the collection
func (array *DNSList) Size() int {
return len(array.items)
}

View File

@@ -0,0 +1,24 @@
package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
}
_, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")
}
}

View File

@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
go func() { go func() {
// blocking // blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Print(err) log.Debug(err)
os.Exit(1)
} }
cancel() cancel()
}() }()
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil { if err != nil {
cmd.Printf("Error: %v\n", err) cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. " + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + "You can verify the connection by running:\n\n" +
"Run the status command: \n\n" + " netbird status\n\n")
" netbird status\n\n" + return err
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil
} }
go func() { go func() {
<-ctx.Done() <-ctx.Done()

View File

@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil) return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@@ -51,6 +51,7 @@ type Manager interface {
dPort *Port, dPort *Port,
direction RuleDirection, direction RuleDirection,
action Action, action Action,
ipsetName string,
comment string, comment string,
) (Rule, error) ) (Rule, error)
@@ -60,5 +61,8 @@ type Manager interface {
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset() error
// Flush the changes to firewall controller
Flush() error
// TODO: migrate routemanager firewal actions to this interface // TODO: migrate routemanager firewal actions to this interface
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall"
@@ -35,6 +36,8 @@ type Manager struct {
inputDefaultRuleSpecs []string inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string outputDefaultRuleSpecs []string
wgIface iFaceMapper wgIface iFaceMapper
rulesets map[string]ruleset
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -43,6 +46,11 @@ type iFaceMapper interface {
Address() iface.WGAddress Address() iface.WGAddress
} }
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
@@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()}, "-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{ outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()}, "-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
} }
// init clients for booth ipv4 and ipv6 // init clients for booth ipv4 and ipv6
@@ -58,14 +71,18 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client m.ipv4Client = ipv4Client
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil { if err != nil {
log.Errorf("ip6tables is not installed in the system or not supported: %v", err) log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else { } else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client m.ipv6Client = ipv6Client
} }
}
if err := m.Reset(); err != nil { if err := m.Reset(); err != nil {
return nil, fmt.Errorf("failed to reset firewall: %v", err) return nil, fmt.Errorf("failed to reset firewall: %v", err)
@@ -73,6 +90,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // If comment is empty rule ID is used as comment
@@ -83,6 +105,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@@ -101,22 +124,45 @@ func (m *Manager) AddFiltering(
if sPort != nil && sPort.Values != nil { if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0]) sPortVal = strconv.Itoa(sPort.Values[0])
} }
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String() ruleID := uuid.New().String()
if comment == "" { if comment == "" {
comment = ruleID comment = ruleID
} }
specs := m.filterRuleSpecs( if ipsetName != "" {
"filter", rs, rsExists := m.rulesets[ipsetName]
ip, if !rsExists {
string(protocol), if err := ipset.Flush(ipsetName); err != nil {
sPortVal, log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
dPortVal, }
direction, if err := ipset.Create(ipsetName); err != nil {
action, return nil, fmt.Errorf("failed to create ipset: %w", err)
comment, }
) }
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
direction, action, comment, ipsetName)
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...) ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
@@ -144,12 +190,24 @@ func (m *Manager) AddFiltering(
} }
} }
return &Rule{ rule := &Rule{
id: ruleID, ruleID: ruleID,
specs: specs, specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT, dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil, v6: ip.To4() == nil,
}, nil }
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to assosiate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
@@ -170,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
client = m.ipv6Client client = m.ipv6Client
} }
if rs, ok := m.rulesets[r.ipsetName]; ok {
// delete IP from ruleset IPs list and ipset
if _, ok := rs.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(rs.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and assosiated firewall rule too
delete(m.rulesets, r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst { if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...) return client.Delete("filter", ChainOutputFilterName, r.specs...)
} }
@@ -193,6 +276,9 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it // reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error { func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName) ok, err := client.ChainExists(table, ChainInputFilterName)
@@ -233,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
return nil return nil
} }
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil return nil
} }
@@ -240,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
func (m *Manager) filterRuleSpecs( func (m *Manager) filterRuleSpecs(
table string, ip net.IP, protocol string, sPort, dPort string, table string, ip net.IP, protocol string, sPort, dPort string,
direction fw.RuleDirection, action fw.Action, comment string, direction fw.RuleDirection, action fw.Action, comment string,
ipsetName string,
) (specs []string) { ) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
@@ -249,13 +346,21 @@ func (m *Manager) filterRuleSpecs(
switch direction { switch direction {
case fw.RuleDirectionIN: case fw.RuleDirectionIN:
if matchByIP { if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String()) specs = append(specs, "-s", ip.String())
} }
}
case fw.RuleDirectionOUT: case fw.RuleDirectionOUT:
if matchByIP { if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String()) specs = append(specs, "-d", ip.String())
} }
} }
}
if protocol != "all" { if protocol != "all" {
specs = append(specs, "-p", protocol) specs = append(specs, "-p", protocol)
} }
@@ -335,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
} }
return "DROP" return "DROP"
} }
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
if ipsetName == "" {
return ""
} else if sPort != "" && dPort != "" {
return ipsetName + "-sport-dport"
} else if sPort != "" {
return ipsetName + "-sport"
} else if dPort != "" {
return ipsetName + "-dport"
}
return ipsetName
}

View File

@@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { err := manager.Reset()
t.Errorf("clear the manager state: %v", err) require.NoError(t, err, "clear the manager state")
}
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -68,7 +69,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
@@ -81,33 +82,31 @@ func TestIptablesManager(t *testing.T) {
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
}) })
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
if err := manager.DeleteRule(rule1); err != nil { err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
}) })
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
if err := manager.DeleteRule(rule2); err != nil { err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
}
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...) require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset()
@@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) {
}) })
} }
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule1 fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
})
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
})
}
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) { func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
exists, err := ipv4Client.Exists("filter", chainName, rulespec...) exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule") require.NoError(t, err, "failed to check rule")
@@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { err := manager.Reset()
t.Errorf("clear the manager state: %v", err) require.NoError(t, err, "clear the manager state")
}
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -167,9 +248,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -2,13 +2,16 @@ package iptables
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
id string ruleID string
ipsetName string
specs []string specs []string
ip string
dst bool dst bool
v6 bool v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) GetRuleID() string {
return r.id return r.ruleID
} }

View File

@@ -6,12 +6,14 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/uuid" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall"
@@ -29,11 +31,14 @@ const (
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
conn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table tableIPv4 *nftables.Table
tableIPv6 *nftables.Table tableIPv6 *nftables.Table
@@ -43,6 +48,10 @@ type Manager struct {
filterInputChainIPv6 *nftables.Chain filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper wgIface iFaceMapper
} }
@@ -54,8 +63,23 @@ type iFaceMapper interface {
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
m := &Manager{ m := &Manager{
conn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface, wgIface: wgIface,
} }
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
var ( var (
err error err error
ipset *nftables.Set
table *nftables.Table table *nftables.Table
chain *nftables.Chain chain *nftables.Chain
) )
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
return nil, err return nil, err
} }
rawIP := ip.To4()
if rawIP == nil {
rawIP = ip.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME ifaceKey = expr.MetaKeyOIFNAME
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
}) })
} }
// don't use IP matching if IP is ip 0.0.0.0 // check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
if s := ip.String(); s != "0.0.0.0" && s != "::" { // in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
var adrLen, adrOffset uint32 addrLen := uint32(len(rawIP))
if ip.To4() == nil { addrOffset := uint32(12)
adrLen = 16 if addrLen == 16 {
adrOffset = 8 addrOffset = 8
} else {
adrLen = 4
adrOffset = 12
} }
// change to destination address position if need // change to destination address position if need
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
adrOffset += adrLen addrOffset += addrLen
} }
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset, Offset: addrOffset,
Len: adrLen, Len: addrLen,
}, },
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: rawIP,
}, },
) )
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
} }
if sPort != nil && len(sPort.Values) != 0 { if sPort != nil && len(sPort.Values) != 0 {
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
id := uuid.New().String() userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
userData := []byte(strings.Join([]string{id, comment}, " "))
_ = m.conn.InsertRule(&nftables.Rule{ rule := m.rConn.InsertRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Position: 0, Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
if err := m.conn.Flush(); err != nil { return nil, fmt.Errorf("flush insert rule: %v", err)
return nil, err
} }
list, err := m.conn.GetRules(table, chain) ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
if err != nil { return m.rulesetManager.addRule(ruleset, rawIP)
return nil, err
} }
// Add the rule to the chain // getRulesetID returns ruleset ID based on given parameters
rule := &Rule{id: id} func (m *Manager) getRulesetID(
for _, r := range list { ip net.IP,
if bytes.Equal(r.UserData, userData) { proto fw.Protocol,
rule.Rule = r sPort *fw.Port,
break dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
} }
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
} }
if rule.Rule == nil { rulesetID += ":"
return nil, fmt.Errorf("rule not found") rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
} }
return rule, nil // createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
} }
// chain returns the chain for the given IP address with specific settings // chain returns the chain for the given IP address with specific settings
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.conn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
} }
} }
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
} }
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
return nil, err return nil, err
} }
chains, err := m.conn.ListChainsOfTableFamily(family) chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of chains: %w", err) return nil, fmt.Errorf("list of chains: %w", err)
} }
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
Policy: &polAccept, Policy: &polAccept,
} }
chain = m.conn.AddChain(chain) chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0 shiftDSTAddr := 0
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
) )
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
}) })
if err := m.conn.Flush(); err != nil {
if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule) nativeRule, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if err := m.conn.DelRule(nativeRule.Rule); err != nil { if nativeRule.nftRule == nil {
return err return nil
} }
return m.conn.Flush() if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
chains, err := m.conn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.conn.DelChain(c) m.rConn.DelChain(c)
} }
} }
tables, err := m.conn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == FilterTableName {
m.conn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return m.conn.Flush() return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil {
return err
}
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime = backoffTime * 2
continue
}
break
}
return
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(table, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
}
}
return nil
} }
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {

View File

@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset()
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
fw.RuleDirectionIN, fw.RuleDirectionIN,
fw.ActionDrop, fw.ActionDrop,
"", "",
"",
) )
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
// 1) regular rule // 1) regular rule
// 2) "accept extra routed traffic rule" for the interface // 2) "accept extra routed traffic rule" for the interface
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
err = manager.DeleteRule(rule) err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(); err != nil {
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
}
} }
require.NoError(t, err, "failed to add rule")
}
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax)) t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
}) })
} }

View File

@@ -6,11 +6,14 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
*nftables.Rule nftRule *nftables.Rule
id string nftSet *nftables.Set
ruleID string
ip []byte
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) GetRuleID() string {
return r.id return r.ruleID
} }

View File

@@ -0,0 +1,115 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -0,0 +1,122 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -1,5 +1,9 @@
package firewall package firewall
import (
"strconv"
)
// Protocol is the protocol of the port // Protocol is the protocol of the port
type Protocol string type Protocol string
@@ -28,3 +32,15 @@ type Port struct {
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int Values []int
} }
// String interface implementation
func (p *Port) String() string {
var ports string
for _, port := range p.Values {
if ports != "" {
ports += ","
}
ports += strconv.Itoa(port)
}
return ports
}

View File

@@ -21,11 +21,13 @@ type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
} }
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules []Rule outgoingRules map[string]RuleSet
incomingRules []Rule incomingRules map[string]RuleSet
rulesIndex map[string]int
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@@ -48,7 +50,6 @@ type decoder struct {
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface IFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rulesIndex: make(map[string]int),
decoders: sync.Pool{ decoders: sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
return d return d
}, },
}, },
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
r := Rule{ r := Rule{
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
var p int
if direction == fw.RuleDirectionIN { if direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules, r) if _, ok := m.incomingRules[r.ip.String()]; !ok {
p = len(m.incomingRules) - 1 m.incomingRules[r.ip.String()] = make(RuleSet)
} else { }
m.outgoingRules = append(m.outgoingRules, r) m.incomingRules[r.ip.String()][r.id] = r
p = len(m.outgoingRules) - 1 } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.rulesIndex[r.id] = p
m.mutex.Unlock() m.mutex.Unlock()
return &r, nil return &r, nil
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
p, ok := m.rulesIndex[r.id] if r.direction == fw.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok { if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.rulesIndex, r.id) delete(m.incomingRules[r.ip.String()], r.id)
var toUpdate []Rule
if r.direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
toUpdate = m.incomingRules
} else { } else {
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...) _, ok := m.outgoingRules[r.ip.String()][r.id]
toUpdate = m.outgoingRules if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
for i := 0; i < len(toUpdate); i++ {
m.rulesIndex[toUpdate[i].id] = i
}
return nil return nil
} }
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = m.outgoingRules[:0] m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = m.incomingRules[:0] m.incomingRules = make(map[string]RuleSet)
m.rulesIndex = make(map[string]int)
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.dropFilter(packetData, m.outgoingRules, false) return m.dropFilter(packetData, m.outgoingRules, false)
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
} }
// dropFilter imlements same logic for booth direction of the traffic // dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool { func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
log.Errorf("unknown layer: %v", d.decoded[0]) log.Errorf("unknown layer: %v", d.decoded[0])
return true return true
} }
payloadLayer := d.decoded[1]
// check if IP address match by IP var ip net.IP
for _, rule := range rules {
if rule.matchByIP {
switch ipLayer { switch ipLayer {
case layers.LayerTypeIPv4: case layers.LayerTypeIPv4:
if isIncomingPacket { if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) { ip = d.ip4.SrcIP
continue
}
} else { } else {
if !d.ip4.DstIP.Equal(rule.ip) { ip = d.ip4.DstIP
continue
}
} }
case layers.LayerTypeIPv6: case layers.LayerTypeIPv6:
if isIncomingPacket { if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) { ip = d.ip6.SrcIP
continue
}
} else { } else {
if !d.ip6.DstIP.Equal(rule.ip) { ip = d.ip6.DstIP
continue
}
}
} }
} }
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
// default policy is DROP ALL
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) {
continue
}
if rule.protoLayer == layerTypeAll { if rule.protoLayer == layerTypeAll {
return rule.drop return rule.drop, true
} }
if payloadLayer != rule.protoLayer { if payloadLayer != rule.protoLayer {
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop return rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule) // if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook // we ignore rule.drop and call this hook
if rule.udpHook != nil { if rule.udpHook != nil {
return rule.udpHook(packetData) return rule.udpHook(packetData), true
} }
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop return rule.drop, true
} }
return rule.drop return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop return rule.drop, true
} }
} }
return false, false
// default policy is DROP ALL
return true
} }
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
} }
m.mutex.Lock() m.mutex.Lock()
var toUpdate []Rule
if in { if in {
r.direction = fw.RuleDirectionIN r.direction = fw.RuleDirectionIN
m.incomingRules = append([]Rule{r}, m.incomingRules...) if _, ok := m.incomingRules[r.ip.String()]; !ok {
toUpdate = m.incomingRules m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
m.outgoingRules = append([]Rule{r}, m.outgoingRules...) if _, ok := m.outgoingRules[r.ip.String()]; !ok {
toUpdate = m.outgoingRules m.outgoingRules[r.ip.String()] = make(map[string]Rule)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@@ -345,15 +359,19 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID { if r.id == hookID {
return m.DeleteRule(&r) return m.DeleteRule(&r)
} }
} }
for _, r := range m.outgoingRules { }
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID { if r.id == hookID {
return m.DeleteRule(&r) return m.DeleteRule(&r)
} }
} }
}
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }

View File

@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
err = m.DeleteRule(rule2) err = m.DeleteRule(rule2)
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule1 still in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager := &Manager{ manager := &Manager{
incomingRules: []Rule{}, incomingRules: map[string]RuleSet{},
outgoingRules: []Rule{}, outgoingRules: map[string]RuleSet{},
rulesIndex: make(map[string]int),
} }
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule Rule
if tt.in { if tt.in {
if len(manager.incomingRules) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
addedRule = manager.incomingRules[0] for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else { } else {
if len(manager.outgoingRules) != 1 { if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
addedRule = manager.outgoingRules[0] for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
} }
if !tt.ip.Equal(addedRule.ip) { if !tt.ip.Equal(addedRule.ip) {
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
} }
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
}) })
} }
} }
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty") t.Errorf("rules is not empty")
} }
} }
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -346,12 +338,14 @@ func TestRemovePacketHook(t *testing.T) {
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID { if rule.id == hookID {
found = true found = true
break break
} }
} }
}
if !found { if !found {
t.Fatalf("The hook was not added properly.") t.Fatalf("The hook was not added properly.")
@@ -364,12 +358,14 @@ func TestRemovePacketHook(t *testing.T) {
} }
// Assert the hook is removed by checking it in the manager's outgoing rules // Assert the hook is removed by checking it in the manager's outgoing rules
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID { if rule.id == hookID {
t.Fatalf("The hook was not removed properly.") t.Fatalf("The hook was not removed properly.")
} }
} }
} }
}
func TestUSPFilterCreatePerformance(t *testing.T) { func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -1,10 +1,13 @@
package acl package acl
import ( import (
"crypto/md5"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -31,10 +34,23 @@ type Manager interface {
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
manager firewall.Manager manager firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule rulesPairs map[string][]firewall.Rule
mutex sync.Mutex mutex sync.Mutex
} }
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
}
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
// //
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains. // If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
@@ -42,11 +58,28 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
start := time.Now()
defer func() {
total := 0
for _, pairs := range d.rulesPairs {
total += len(pairs)
}
log.Infof(
"ACL rules processed in: %v, total rules count: %d",
time.Since(start), total)
}()
if d.manager == nil { if d.manager == nil {
log.Debug("firewall manager is not supported, skipping firewall rules") log.Debug("firewall manager is not supported, skipping firewall rules")
return return
} }
defer func() {
if err := d.manager.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil && enableSSH := (networkMap.PeerConfig != nil &&
@@ -94,14 +127,38 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
applyFailed := false applyFailed := false
newRulePairs := make(map[string][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules { for _, r := range rules {
rulePair, err := d.protoRuleToFirewallRule(r) selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName = ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true applyFailed = true
break break
} }
newRulePairs[rulePair[0].GetRuleID()] = rulePair newRulePairs[pairID] = rulePair
} }
if applyFailed { if applyFailed {
log.Error("failed to apply firewall rules, rollback ACL to previous state") log.Error("failed to apply firewall rules, rollback ACL to previous state")
@@ -140,55 +197,71 @@ func (d *DefaultManager) Stop() {
} }
} }
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) { func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip := net.ParseIP(r.PeerIP)
if ip == nil { if ip == nil {
return nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
} }
protocol := convertToFirewallProtocol(r.Protocol) protocol := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown { if protocol == firewall.ProtocolUnknown {
return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
} }
action := convertFirewallAction(r.Action) action := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown { if action == firewall.ActionUnknown {
return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
} }
var port *firewall.Port var port *firewall.Port
if r.Port != "" { if r.Port != "" {
value, err := strconv.Atoi(r.Port) value, err := strconv.Atoi(r.Port)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid port, skipping firewall rule") return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
} }
port = &firewall.Port{ port = &firewall.Port{
Values: []int{value}, Values: []int{value},
} }
} }
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule var rules []firewall.Rule
var err error var err error
switch r.Direction { switch r.Direction {
case mgmProto.FirewallRule_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT: case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
if err != nil { if err != nil {
return nil, err return "", nil, err
} }
d.rulesPairs[rules[0].GetRuleID()] = rules d.rulesPairs[ruleID] = rules
return rules, nil return ruleID, rules, nil
} }
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -198,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -206,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return append(rules, rule), nil return append(rules, rule), nil
} }
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -218,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -226,6 +309,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
return append(rules, rule), nil return append(rules, rule), nil
} }
// getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic. // to all peers in the network map to one rule which just accepts that type of the traffic.
// //
@@ -235,7 +335,7 @@ func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap, networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { ) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
totalIPs := 0 totalIPs := 0
for _, p := range networkMap.RemotePeers { for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps { for range p.AllowedIps {
totalIPs++ totalIPs++
} }
@@ -246,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
in := protoMatch{} in := protoMatch{}
out := protoMatch{} out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not. // this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list. // We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if: // But we zeroed the IP's for protocol if:
@@ -262,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = map[string]int{}
} }
match := protocols[r.Protocol]
if _, ok := match[r.PeerIP]; ok { // special case, when we recieve this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return return
} }
match[r.PeerIP] = i
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
} }
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
@@ -288,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.FirewallRule_UDP, mgmProto.FirewallRule_UDP,
} }
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
@@ -346,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
return append(rules, squashedRules...), squashedProtocols return append(rules, squashedRules...), squashedProtocols
} }
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
switch protocol { switch protocol {
case mgmProto.FirewallRule_TCP: case mgmProto.FirewallRule_TCP:

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
} }
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }

View File

@@ -55,6 +55,11 @@ func TestDefaultManager(t *testing.T) {
}) })
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
for id := range acl.rulesPairs {
existedPairs[id] = struct{}{}
}
// remove first rule // remove first rule
networkMap.FirewallRules = networkMap.FirewallRules[1:] networkMap.FirewallRules = networkMap.FirewallRules[1:]
networkMap.FirewallRules = append( networkMap.FirewallRules = append(
@@ -67,11 +72,6 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
existedRulesID := map[string]struct{}{}
for id := range acl.rulesPairs {
existedRulesID[id] = struct{}{}
}
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
@@ -80,13 +80,16 @@ func TestDefaultManager(t *testing.T) {
return return
} }
// check that old rules was removed // check that old rule was removed
for id := range existedRulesID { previousCount := 0
if _, ok := acl.rulesPairs[id]; ok { for id := range acl.rulesPairs {
t.Errorf("old rule was not removed") if _, ok := existedPairs[id]; ok {
return previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed")
}
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {

View File

@@ -215,11 +215,13 @@ func update(input ConfigInput) (*Config, error) {
} }
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey { if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, updated to %s (old value %s)", if *input.PreSharedKey != "" {
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
*input.PreSharedKey, config.PreSharedKey) *input.PreSharedKey, config.PreSharedKey)
config.PreSharedKey = *input.PreSharedKey config.PreSharedKey = *input.PreSharedKey
refresh = true refresh = true
} }
}
if config.SSHKey == "" { if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519) pem, err := ssh.GeneratePrivateKey(ssh.ED25519)

View File

@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), managementURL) assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey) assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 4: existing config, but new managementURL has been provided -> update config // case 4: new empty pre-shared key config -> fetch it
newPreSharedKey := ""
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: path,
PreSharedKey: &newPreSharedKey,
})
if err != nil {
return
}
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 5: existing config, but new managementURL has been provided -> update config
newManagementURL := "https://test.newManagement.url:33071" newManagementURL := "https://test.newManagement.url:33071"
config, err = UpdateOrCreateConfig(ConfigInput{ config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: newManagementURL, ManagementURL: newManagementURL,

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -24,7 +25,24 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{})
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err) return wrapErr(err)
} }
// in case of non Android os these variables will be nil engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@@ -1,13 +1,9 @@
package dns package dns
import (
"github.com/netbirdio/netbird/iface"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }

View File

@@ -9,8 +9,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -34,7 +32,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
} }
func newHostManager(_ *iface.WGIface) (hostManager, error) { func newHostManager(_ WGIface) (hostManager, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil

View File

@@ -5,10 +5,10 @@ package dns
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os" "os"
"strings" "strings"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -25,7 +25,7 @@ const (
type osManagerType int type osManagerType int
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -6,8 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -33,7 +31,7 @@ type registryConfigurator struct {
existingSearchDomains []string existingSearchDomains []string
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
return "" return ""
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
//TODO implement me
panic("implement me")
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface // UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil { if m.UpdateDNSServerFunc != nil {

View File

@@ -14,8 +14,6 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -8,8 +8,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@@ -18,7 +16,7 @@ type resolvconf struct {
ifaceName string ifaceName string
} }
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
return &resolvconf{ return &resolvconf{
ifaceName: wgInterface.Name(), ifaceName: wgInterface.Name(),
}, nil }, nil

View File

@@ -3,29 +3,20 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"math/big"
"net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
defaultPort = 53 type ReadyListener interface {
customPort = 5053 OnReady()
defaultIP = "127.0.0.1" }
customIP = "127.0.0.153"
)
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
@@ -33,6 +24,7 @@ type Server interface {
Stop() Stop()
DnsIP() string DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
} }
type registeredHandlerMap map[string]handlerWithStop type registeredHandlerMap map[string]handlerWithStop
@@ -42,21 +34,19 @@ type DefaultServer struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
mux sync.Mutex mux sync.Mutex
fakeResolverWG sync.WaitGroup service service
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
localResolver *localResolver localResolver *localResolver
wgInterface *iface.WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
previousConfigHash uint64 previousConfigHash uint64
currentConfig hostDNSConfig currentConfig hostDNSConfig
customAddress *netip.AddrPort
enabled bool // permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
} }
type handlerWithStop interface { type handlerWithStop interface {
@@ -70,9 +60,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) { func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
mux := dns.NewServeMux()
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -82,34 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
addrPort = &parsedAddrPort addrPort = &parsedAddrPort
} }
ctx, stop := context.WithCancel(ctx) var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
setServerDns(ds)
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
server: &dns.Server{ service: dnsService,
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
customAddress: addrPort,
} }
if initialDnsCfg != nil { return defaultServer
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
} }
defaultServer.evalRuntimeAddress() // Initialize instantiate host manager and the dns service
return defaultServer, nil
}
// Initialize instantiate host manager. It required to be initialized wginterface
func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -118,74 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
return nil return nil
} }
if s.permanent {
err = s.service.Listen()
if err != nil {
return err
}
}
s.hostManager, err = newHostManager(s.wgInterface) s.hostManager, err = newHostManager(s.wgInterface)
return return
} }
// listen runs the listener in a go routine // DnsIP returns the DNS resolver server IP address
func (s *DefaultServer) listen() { //
// nil check required in unit tests // When kernel space interface used it return real DNS server listener IP address
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
return
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
func (s *DefaultServer) DnsIP() string { func (s *DefaultServer) DnsIP() string {
if !s.enabled { return s.service.RuntimeIP()
return ""
}
return s.runtimeIP
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
} }
// Stop stops the server // Stop stops the server
@@ -194,34 +141,30 @@ func (s *DefaultServer) Stop() {
defer s.mux.Unlock() defer s.mux.Unlock()
s.ctxCancel() s.ctxCancel()
if s.hostManager != nil {
err := s.hostManager.restoreHostDNS() err := s.hostManager.restoreHostDNS()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
} }
err = s.stopListener() s.service.Stop()
if err != nil {
log.Error(err)
}
} }
func (s *DefaultServer) stopListener() error { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
if !s.listenerIsRunning { // It will be applied if the mgm server do not enforce DNS settings for root zone
return nil func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
} s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) s.hostsDnsList = hostsDnsList
defer cancel() _, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
err := s.server.ShutdownContext(ctx) log.Debugf("on new host DNS config but skip to apply it")
if err != nil { return
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
} }
return nil log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone()
} }
// UpdateDNSServer processes an update received from the management service // UpdateDNSServer processes an update received from the management service
@@ -272,16 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver // is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable { if update.ServiceEnable {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { _ = s.service.Listen()
s.fakeResolverWG.Done() } else if !s.permanent {
} else { s.service.Stop()
if err := s.stopListener(); err != nil {
log.Error(err)
}
}
} else if !s.listenerIsRunning {
s.listen()
} }
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
@@ -292,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil { if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err) return fmt.Errorf("not applying dns update, error: %v", err)
} }
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords) s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig hostUpdate := s.currentConfig
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false hostUpdate.routeAll = false
@@ -405,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap) muxUpdateMap := make(registeredHandlerMap)
var isContainRootUpdate bool
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler) s.service.RegisterMux(update.domain, update.handler)
muxUpdateMap[update.domain] = update.handler muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop() existingHandler.stop()
} }
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
} }
for key, existingHandler := range s.dnsMuxMap { for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop() existingHandler.stop()
s.deregisterMux(key) } else {
existingHandler.stop()
s.service.DeregisterMux(key)
}
} }
} }
@@ -448,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} }
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -483,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
for i, item := range s.currentConfig.domains { for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found { if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true s.currentConfig.domains[i].disabled = true
s.deregisterMux(item.domain) s.service.DeregisterMux(item.domain)
removeIndex[item.domain] = i removeIndex[item.domain] = i
} }
} }
@@ -500,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.domains[i].disabled = false s.currentConfig.domains[i].disabled = false
s.registerMux(domain, handler) s.service.RegisterMux(domain, handler)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@@ -516,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
return return
} }
func (s *DefaultServer) filterDNSTraffic() string { func (s *DefaultServer) addHostRootZone() {
filter := s.wgInterface.GetFilter() handler := newUpstreamResolver(s.ctx)
if filter == nil { handler.upstreamServers = make([]string, len(s.hostsDnsList))
log.Error("can't set DNS filter, filter not initialized") for n, ua := range s.hostsDnsList {
return "" handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
} }
handler.deactivate = func() {}
firstLayerDecoder := layers.LayerTypeIPv4 handler.reactivate = func() {}
if s.wgInterface.Address().Network.IP.To4() == nil { s.service.RegisterMux(nbdns.RootZone, handler)
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
return
}
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
return
}
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}
func hasValidDnsServer(cfg *nbdns.Config) bool {
for _, c := range cfg.NameServerGroups {
if c.Primary {
return true
}
}
return false
} }

View File

@@ -0,0 +1,29 @@
package dns
import (
"fmt"
"sync"
)
var (
mutex sync.Mutex
server Server
)
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
func GetServerDns() (Server, error) {
mutex.Lock()
if server == nil {
mutex.Unlock()
return nil, fmt.Errorf("DNS server not instantiated yet")
}
s := server
mutex.Unlock()
return s, nil
}
func setServerDns(newServerServer Server) {
mutex.Lock()
server = newServerServer
defer mutex.Unlock()
}

View File

@@ -0,0 +1,24 @@
package dns
import (
"testing"
)
func TestGetServerDns(t *testing.T) {
_, err := GetServerDns()
if err == nil {
t.Errorf("invalid dns server instance")
}
srv := &MockServer{}
setServerDns(srv)
srvB, err := GetServerDns()
if err != nil {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
t.Errorf("missmatch dns instances")
}
}

View File

@@ -5,17 +5,59 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/miekg/dns" "github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
) )
type mocWGIface struct {
filter iface.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me")
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}
func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter
return nil
}
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
{ {
Name: "peera.netbird.cloud", Name: "peera.netbird.cloud",
@@ -26,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
}, },
} }
func init() {
log.SetLevel(log.TraceLevel)
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
@@ -221,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -239,9 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil { if err != nil {
@@ -276,6 +320,133 @@ func TestUpdateDNSServer(t *testing.T) {
} }
} }
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) { func TestDNSServerStartStop(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -292,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
dnsServer.hostManager = newNoopHostMocker() t.Fatalf("%v", err)
dnsServer.listen()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
} }
dnsServer.hostManager = newNoopHostMocker()
err = dnsServer.service.Listen()
if err != nil {
t.Fatalf("dns server is not running: %s", err)
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop() defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0]) err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver) dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@@ -314,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{ d := net.Dialer{
Timeout: time.Second * 5, Timeout: time.Second * 5,
} }
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort) addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
conn, err := d.DialContext(ctx, network, addr) conn, err := d.DialContext(ctx, network, addr)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
@@ -349,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
dnsMux: dns.DefaultServeMux, service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@@ -412,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
} }
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer { func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
mux := dns.NewServeMux() wgIFace, err := createWgInterfaceWithBind(t)
var parsedAddrPort *netip.AddrPort
if addrPort != "" {
parsed, err := netip.ParseAddrPort(addrPort)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal("failed to initialize wg interface")
}
parsedAddrPort = &parsed
} }
defer wgIFace.Close()
dnsServer := &dns.Server{ var dnsList []string
Net: "udp", dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
Handler: mux, err = dnsServer.Initialize()
UDPSize: 65535,
}
ctx, cancel := context.WithCancel(context.TODO())
ds := &DefaultServer{
ctx: ctx,
ctxCancel: cancel,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
customAddress: parsedAddrPort,
}
ds.evalRuntimeAddress()
return ds
}
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil { if err != nil {
t.Errorf("Error parsing CIDR: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
return return
} }
defer dnsServer.Stop()
lastIP := getLastIPFromNetwork(ipnet, 1) dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
} }
} }
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Enabled: true,
Primary: true,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
update2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{},
}
err = dnsServer.UpdateDNSServer(2, update2)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, nil
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
}
err = wgIface.Create()
if err != nil {
t.Fatalf("crate and init wireguard interface: %v", err)
return nil, err
}
pf, err := uspfilter.Create(wgIface)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
}
err = wgIface.SetFilter(pf)
if err != nil {
t.Fatalf("set packet filter: %v", err)
return nil, err
}
return wgIface, nil
}
func newDnsResolver(ip string, port int) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 3,
}
addr := fmt.Sprintf("%s:%d", ip, port)
return d.DialContext(ctx, network, addr)
},
}
} }

View File

@@ -0,0 +1,18 @@
package dns
import (
"github.com/miekg/dns"
)
const (
defaultPort = 53
)
type service interface {
Listen() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
}

View File

@@ -0,0 +1,145 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
type serviceViaListener struct {
wgInterface WGIface
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
runtimeIP string
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
}
return s
}
func (s *serviceViaListener) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
return nil
}
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}

View File

@@ -0,0 +1,139 @@
package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -15,7 +15,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name()) iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -0,0 +1,14 @@
//go:build !windows
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
}

View File

@@ -0,0 +1,13 @@
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
GetInterfaceGUIDString() (string, error)
}

View File

@@ -190,23 +190,25 @@ func (e *Engine) Start() error {
} }
var routes []*route.Route var routes []*route.Route
var dnsCfg *nbdns.Config
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
routes, dnsCfg, err = e.readInitialSettings() routes, err = e.readInitialSettings()
if err != nil { if err != nil {
return err return err
} }
}
if e.dnsServer == nil { if e.dnsServer == nil {
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
go e.mobileDep.DnsReadyListener.OnReady()
}
} else {
// todo fix custom address // todo fix custom address
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil { if err != nil {
e.close() e.close()
return err return err
} }
e.dnsServer = dnsServer }
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
@@ -605,6 +607,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// cleanup request, most likely our peer has been deleted // cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() { if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers() err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil { if err != nil {
return err return err
} }
@@ -624,6 +627,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err return err
} }
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys // update SSHServer by adding remote peer SSH keys
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() { for _, config := range networkMap.GetRemotePeers() {
@@ -759,17 +764,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
} }
e.peerConns[peerKey] = conn e.peerConns[peerKey] = conn
err = e.statusRecorder.AddPeer(peerKey) err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil { if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
} }
go e.connWorker(conn, peerKey) go e.connWorker(conn, peerKey)
} }
err := e.statusRecorder.UpdatePeerFQDN(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerKey, err)
}
return nil return nil
} }
@@ -1046,14 +1047,13 @@ func (e *Engine) close() {
} }
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, error) {
netMap, err := e.mgmClient.GetNetworkMap() netMap, err := e.mgmClient.GetNetworkMap()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
routes := toRoutes(netMap.GetRoutes()) routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, nil
return routes, &dnsCfg, nil
} }
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {

View File

@@ -1,6 +1,7 @@
package internal package internal
import ( import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -11,4 +12,6 @@ type MobileDependency struct {
TunAdapter iface.TunAdapter TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
RouteListener routemanager.RouteListener RouteListener routemanager.RouteListener
HostDNSAddresses []string
DnsReadyListener dns.ReadyListener
} }

View File

@@ -59,6 +59,11 @@ type Status struct {
mgmAddress string mgmAddress string
signalAddress string signalAddress string
notifier *notifier notifier *notifier
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
@@ -78,11 +83,13 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
defer d.mux.Unlock() defer d.mux.Unlock()
d.offlinePeers = make([]State, len(replacement)) d.offlinePeers = make([]State, len(replacement))
copy(d.offlinePeers, replacement) copy(d.offlinePeers, replacement)
d.notifyPeerListChanged()
// todo we should set to true in case if the list changed only
d.peerListChangedForNotification = true
} }
// AddPeer adds peer to Daemon status map // AddPeer adds peer to Daemon status map
func (d *Status) AddPeer(peerPubKey string) error { func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -90,7 +97,12 @@ func (d *Status) AddPeer(peerPubKey string) error {
if ok { if ok {
return errors.New("peer already exist") return errors.New("peer already exist")
} }
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected} d.peers[peerPubKey] = State{
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
}
d.peerListChangedForNotification = true
return nil return nil
} }
@@ -112,13 +124,13 @@ func (d *Status) RemovePeer(peerPubKey string) error {
defer d.mux.Unlock() defer d.mux.Unlock()
_, ok := d.peers[peerPubKey] _, ok := d.peers[peerPubKey]
if ok { if !ok {
delete(d.peers, peerPubKey) return errors.New("no peer with to remove")
return nil
} }
d.notifyPeerListChanged() delete(d.peers, peerPubKey)
return errors.New("no peer with to remove") d.peerListChangedForNotification = true
return nil
} }
// UpdatePeerState updates peer status // UpdatePeerState updates peer status
@@ -188,10 +200,23 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
peerState.FQDN = fqdn peerState.FQDN = fqdn
d.peers[peerPubKey] = peerState d.peers[peerPubKey] = peerState
d.notifyPeerListChanged()
return nil return nil
} }
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()
if !d.peerListChangedForNotification {
d.mux.Unlock()
return
}
d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged()
}
// GetPeerStateChangeNotifier returns a change notifier channel for a peer // GetPeerStateChangeNotifier returns a change notifier channel for a peer
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock() d.mux.Lock()

View File

@@ -9,13 +9,13 @@ import (
func TestAddPeer(t *testing.T) { func TestAddPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key, "abc.netbird")
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
_, exists := status.peers[key] _, exists := status.peers[key]
assert.True(t, exists, "value was found") assert.True(t, exists, "value was found")
err = status.AddPeer(key) err = status.AddPeer(key, "abc.netbird")
assert.Error(t, err, "should return error on duplicate") assert.Error(t, err, "should return error on duplicate")
} }
@@ -23,7 +23,7 @@ func TestAddPeer(t *testing.T) {
func TestGetPeer(t *testing.T) { func TestGetPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key, "abc.netbird")
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
peerStatus, err := status.GetPeer(key) peerStatus, err := status.GetPeer(key)

View File

@@ -6,8 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -30,34 +28,14 @@ func genKey(format string, input string) string {
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager // NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager { func NewFirewall(parentCTX context.Context) firewallManager {
ctx, cancel := context.WithCancel(parentCTX) manager, err := newNFTablesManager(parentCTX)
if err == nil {
if isIptablesSupported() { log.Debugf("nftables firewall manager will be used")
log.Debugf("iptables is supported")
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
return &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
}
log.Debugf("iptables is not supported, using nftables")
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
return manager return manager
} }
log.Debugf("fallback to iptables firewall manager: %s", err)
return newIptablesManager(parentCTX)
}
func getInPair(pair routerPair) routerPair { func getInPair(pair routerPair) routerPair {
return routerPair{ return routerPair{

View File

@@ -49,6 +49,28 @@ type iptablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newIptablesManager(parentCtx context.Context) *iptablesManager {
ctx, cancel := context.WithCancel(parentCtx)
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if !isIptablesClientAvailable(ipv4Client) {
log.Infof("iptables is missing for ipv4")
ipv4Client = nil
}
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
}
return &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent // CleanRoutingRules cleans existing iptables resources that we created by the agent
func (i *iptablesManager) CleanRoutingRules() { func (i *iptablesManager) CleanRoutingRules() {
i.mux.Lock() i.mux.Lock()
@@ -61,6 +83,7 @@ func (i *iptablesManager) CleanRoutingRules() {
log.Debug("flushing tables") log.Debug("flushing tables")
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
if i.ipv4Client != nil {
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil { if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
@@ -70,7 +93,9 @@ func (i *iptablesManager) CleanRoutingRules() {
if err != nil { if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
} }
}
if i.ipv6Client != nil {
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil { if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
@@ -80,6 +105,7 @@ func (i *iptablesManager) CleanRoutingRules() {
if err != nil { if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
} }
}
log.Info("done cleaning up iptables rules") log.Info("done cleaning up iptables rules")
} }
@@ -96,6 +122,7 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
errMSGFormat := "iptables: failed creating %s chain %s,error: %v" errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
if i.ipv4Client != nil {
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain) err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil { if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
@@ -106,7 +133,14 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
} }
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain) err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
}
if i.ipv6Client != nil {
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil { if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
} }
@@ -116,17 +150,13 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
} }
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
err = i.restoreRules(i.ipv6Client) err = i.restoreRules(i.ipv6Client)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err) return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
} }
}
err = i.addJumpRules() err := i.addJumpRules()
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while creating jump rules: %v", err) return fmt.Errorf("iptables: error while creating jump rules: %v", err)
} }
@@ -140,12 +170,13 @@ func (i *iptablesManager) addJumpRules() error {
if err != nil { if err != nil {
return err return err
} }
if i.ipv4Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil { if err != nil {
return err return err
} }
i.rules[ipv4][ipv4Forwarding] = rule i.rules[ipv4][ipv4Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat) rule = append(iptablesDefaultNatRule, ipv4Nat)
@@ -154,8 +185,10 @@ func (i *iptablesManager) addJumpRules() error {
return err return err
} }
i.rules[ipv4][ipv4Nat] = rule i.rules[ipv4][ipv4Nat] = rule
}
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding) if i.ipv6Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil { if err != nil {
return err return err
@@ -168,6 +201,7 @@ func (i *iptablesManager) addJumpRules() error {
return err return err
} }
i.rules[ipv6][ipv6Nat] = rule i.rules[ipv6][ipv6Nat] = rule
}
return nil return nil
} }
@@ -177,6 +211,7 @@ func (i *iptablesManager) cleanJumpRules() error {
var err error var err error
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
rule, found := i.rules[ipv4][ipv4Forwarding] rule, found := i.rules[ipv4][ipv4Forwarding]
if i.ipv4Client != nil {
if found { if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding) log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
@@ -192,6 +227,8 @@ func (i *iptablesManager) cleanJumpRules() error {
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
} }
} }
}
if i.ipv6Client == nil {
rule, found = i.rules[ipv6][ipv6Forwarding] rule, found = i.rules[ipv6][ipv6Forwarding]
if found { if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding) log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
@@ -208,6 +245,7 @@ func (i *iptablesManager) cleanJumpRules() error {
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err) return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
} }
} }
}
return nil return nil
} }
@@ -437,3 +475,8 @@ func getIptablesRuleType(table string) string {
} }
return ruleType return ruleType
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
ctx, cancel := context.WithCancel(context.TODO()) manager := newIptablesManager(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist") require.True(t, exists, "forwarding rule should exist")
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist") require.True(t, exists, "forwarding rule should exist")
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
forward4RuleKey := genKey(forwardingFormat, pair.ID) forward4RuleKey := genKey(forwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := genKey(natFormat, pair.ID) nat4RuleKey := genKey(natFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
pair = routerPair{ pair = routerPair{
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
forward6RuleKey := genKey(forwardingFormat, pair.ID) forward6RuleKey := genKey(forwardingFormat, pair.ID)
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat6RuleKey := genKey(natFormat, pair.ID) nat6RuleKey := genKey(natFormat, pair.ID)
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4) delete(manager.rules, ipv4)

View File

@@ -19,6 +19,9 @@ const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd" nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat" nftablesRoutingNatChain = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
) )
// constants needed to create nftable rules // constants needed to create nftable rules
@@ -78,9 +81,36 @@ type nftablesManager struct {
tableIPv6 *nftables.Table tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
filterTable *nftables.Table
defaultForwardRules []*nftables.Rule
mux sync.Mutex mux sync.Mutex
} }
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
}
// CleanRoutingRules cleans existing nftables rules from the system // CleanRoutingRules cleans existing nftables rules from the system
func (n *nftablesManager) CleanRoutingRules() { func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock() n.mux.Lock()
@@ -90,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
n.conn.FlushTable(n.tableIPv6) n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4) n.conn.FlushTable(n.tableIPv4)
} }
if n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete forward rule: %s", err)
}
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush()) log.Debugf("flushing tables result in: %v error", n.conn.Flush())
} }
@@ -222,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil return nil
} }
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
}
err := n.refreshDefaultForwardRule()
if err != nil {
return err
}
for i, r := range n.defaultForwardRules {
err = n.conn.DelRule(r)
if err != nil {
log.Errorf("failed to delete forward rule (%d): %s", i, err)
}
n.defaultForwardRules[i] = nil
}
return nil
}
func (n *nftablesManager) refreshDefaultForwardRule() error {
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
if err != nil {
return fmt.Errorf("unable to list rules in forward chain: %s", err)
}
found := false
for i, r := range n.defaultForwardRules {
for _, rule := range rules {
if string(rule.UserData) == string(r.UserData) {
n.defaultForwardRules[i] = rule
found = true
break
}
}
}
if !found {
return fmt.Errorf("unable to find forward accept rule")
}
return nil
}
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
src := generateCIDRMatcherExpressions("source", sourceNetwork)
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r := &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
n.defaultForwardRules[0] = n.conn.AddRule(r)
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r = &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
n.defaultForwardRules[1] = n.conn.AddRule(r)
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled // checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding] _, foundIPv4 := n.rules[ipv4Forwarding]
@@ -275,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
} }
} }
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
err = n.acceptForwardRule(pair.source)
if err != nil {
log.Errorf("unable to create default forward rule: %s", err)
}
log.Debugf("default accept forward rule added")
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
@@ -355,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err return err
} }
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delte default fwd rule: %s", err)
}
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -386,6 +544,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil return nil
} }
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -14,21 +14,16 @@ import (
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
for _, testCase := range insertRuleTestCases { for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair) err = manager.InsertRoutingRules(testCase.inputPair)
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
for _, testCase := range removeRuleTestCases { for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4 table := manager.tableIPv4

View File

@@ -0,0 +1,82 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package routemanager
import (
"fmt"
"net"
"net/netip"
"syscall"
"golang.org/x/net/route"
)
// selected BSD Route flags.
const (
RTF_UP = 0x1
RTF_GATEWAY = 0x2
RTF_HOST = 0x4
RTF_REJECT = 0x8
RTF_DYNAMIC = 0x10
RTF_MODIFIED = 0x20
RTF_STATIC = 0x800
RTF_BLACKHOLE = 0x1000
RTF_LOCAL = 0x200000
RTF_BROADCAST = 0x400000
RTF_MULTICAST = 0x800000
)
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return false, err
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
if err != nil {
return false, err
}
for _, msg := range msgs {
m := msg.(*route.RouteMessage)
if m.Version < 3 || m.Version > 5 {
return false, fmt.Errorf("unexpected RIB message version: %d", m.Version)
}
if m.Type != 4 /* RTM_GET */ {
return true, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if m.Flags&RTF_UP == 0 ||
m.Flags&(RTF_REJECT|RTF_BLACKHOLE) != 0 {
continue
}
dst, err := toIPAddr(m.Addrs[0])
if err != nil {
return true, fmt.Errorf("unexpected RIB destination: %v", err)
}
mask, _ := toIPAddr(m.Addrs[2])
cidr, _ := net.IPMask(mask.To4()).Size()
if dst.String() == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
return false, nil
}
func toIPAddr(a route.Addr) (net.IP, error) {
switch t := a.(type) {
case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return ip, nil
case *route.Inet6Addr:
ip := make(net.IP, net.IPv6len)
copy(ip, t.IP[:])
return ip, nil
default:
return net.IP{}, fmt.Errorf("unknown family: %v", t)
}
}

View File

@@ -6,10 +6,28 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"syscall"
"unsafe"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
type routeInfoInMemory struct {
Family byte
DstLen byte
SrcLen byte
TOS byte
Table byte
Protocol byte
Scope byte
Type byte
Flags uint32
}
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
func addToRouteTable(prefix netip.Prefix, addr string) error { func addToRouteTable(prefix netip.Prefix, addr string) error {
@@ -61,6 +79,45 @@ func removeFromRouteTable(prefix netip.Prefix) error {
return nil return nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
if err != nil {
return true, err
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return true, err
}
loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
break loop
case syscall.RTM_NEWROUTE:
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil {
return true, err
}
if rt.Family != syscall.AF_INET {
continue loop
}
for _, attr := range attrs {
if attr.Attr.Type == syscall.RTA_DST {
ip := net.IP(attr.Value)
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
cidr, _ := mask.Size()
if ip.String() == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
}
}
}
return false, nil
}
func enableIPForwarding() error { func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath) bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil { if err != nil {

View File

@@ -14,19 +14,26 @@ import (
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
prefixGateway, err := getExistingRIBRouteGateway(prefix)
if err != nil && err != errRouteNotFound { if err != nil && err != errRouteNotFound {
return err return err
} }
if prefixGateway != nil && !prefixGateway.Equal(gateway) { gatewayIP := netip.MustParseAddr(defaultGateway.String())
log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway) if prefix.Contains(gatewayIP) {
log.Warnf("skipping adding a new route for network %s because it overlaps with the default gateway: %s", prefix, gatewayIP)
return nil return nil
} }
ok, err := existsInRouteTable(prefix)
if err != nil {
return err
}
if ok {
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
return nil
}
return addToRouteTable(prefix, addr) return addToRouteTable(prefix, addr)
} }
@@ -53,6 +60,7 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
log.Errorf("getting routes returned an error: %v", err) log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound return nil, errRouteNotFound
} }
if gateway == nil { if gateway == nil {
return preferredSrc, nil return preferredSrc, nil
} }

View File

@@ -1,13 +1,19 @@
package routemanager package routemanager
import ( import (
"bytes"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require"
"net" "net"
"net/netip" "net/netip"
"os"
"strings"
"testing" "testing"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/iface"
) )
func TestAddRemoveRoutes(t *testing.T) { func TestAddRemoveRoutes(t *testing.T) {
@@ -114,3 +120,98 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
} }
} }
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
fmt.Println("defaultGateway: ", defaultGateway)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
MockAddr := wgInterface.Address().IP.String()
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
fmt.Println("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}

View File

@@ -0,0 +1,37 @@
//go:build windows
// +build windows
package routemanager
import (
"net"
"net/netip"
"github.com/yusufpapurcu/wmi"
)
type Win32_IP4RouteTable struct {
Destination string
Mask string
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
err := wmi.Query(query, &routes)
if err != nil {
return true, err
}
for _, route := range routes {
ip := net.ParseIP(route.Mask)
ip = ip.To4()
mask := net.IPv4Mask(ip[0], ip[1], ip[2], ip[3])
cidr, _ := mask.Size()
if route.Destination == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
return false, nil
}

View File

@@ -102,7 +102,7 @@ func (s *Server) Start() error {
} }
go func() { go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
log.Errorf("init connections: %v", err) log.Errorf("init connections: %v", err)
} }
}() }()
@@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
go func() { go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
log.Errorf("run client connection: %v", err) log.Errorf("run client connection: %v", err)
return return
} }

View File

@@ -2,9 +2,6 @@ package ssh
import ( import (
"fmt" "fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io" "io"
"net" "net"
"os" "os"
@@ -13,11 +10,22 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
) )
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338 const DefaultSSHPort = 44338
// TerminalTimeout is the timeout for terminal session to be ready
const TerminalTimeout = 10 * time.Second
// TerminalBackoffDelay is the delay between terminal session readiness checks
const TerminalBackoffDelay = 500 * time.Millisecond
// DefaultSSHServer is a function that creates DefaultServer // DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr) return newDefaultServer(hostKeyPEM, addr)
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
}() }()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User()) localUser, err := userNameLookup(session.User())
if err != nil { if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
} }
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd) file, err := pty.Start(cmd)
if err != nil { if err != nil {
log.Errorf("failed starting SSH server %v", err) log.Errorf("failed starting SSH server %v", err)
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return return
} }
} }
log.Debugf("SSH session ended")
} }
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
// stdin // stdin
_, err := io.Copy(file, session) _, err := io.Copy(file, session)
if err != nil { if err != nil {
_ = session.Exit(1)
return return
} }
}() }()
go func() { // AWS Linux 2 machines need some time to open the terminal so we need to wait for it
timer := time.NewTimer(TerminalTimeout)
for {
select {
case <-timer.C:
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
_ = session.Exit(1)
return
default:
// stdout // stdout
_, err := io.Copy(session, file) writtenBytes, err := io.Copy(session, file)
if err != nil { if err != nil && writtenBytes != 0 {
_ = session.Exit(0)
return return
} }
}() time.Sleep(TerminalBackoffDelay)
}
}
} }
// Start starts SSH server. Blocking // Start starts SSH server. Blocking

23
go.mod
View File

@@ -17,12 +17,12 @@ require (
github.com/spf13/cobra v1.6.1 github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0 golang.org/x/crypto v0.9.0
golang.org/x/sys v0.8.0 golang.org/x/sys v0.8.0
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.52.3 google.golang.org/grpc v1.55.0
google.golang.org/protobuf v1.30.0 google.golang.org/protobuf v1.30.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -48,6 +48,7 @@ require (
github.com/mdlayher/socket v0.4.0 github.com/mdlayher/socket v0.4.0
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2 github.com/pion/logging v0.2.2
@@ -57,6 +58,7 @@ require (
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
github.com/yusufpapurcu/wmi v1.2.3
go.opentelemetry.io/otel v1.11.1 go.opentelemetry.io/otel v1.11.1
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0 go.opentelemetry.io/otel/metric v0.33.0
@@ -65,18 +67,22 @@ require (
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.10.0 golang.org/x/net v0.10.0
golang.org/x/sync v0.1.0 golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.2.0
golang.org/x/term v0.8.0 golang.org/x/term v0.8.0
google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
cloud.google.com/go/compute v1.19.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/BurntSushi/toml v1.2.1 // indirect github.com/BurntSushi/toml v1.2.1 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
@@ -92,9 +98,14 @@ require (
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v1.0.0 // indirect github.com/josharian/native v1.0.0 // indirect
@@ -120,17 +131,17 @@ require (
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/yuin/goldmark v1.4.13 // indirect github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/image v0.5.0 // indirect golang.org/x/image v0.5.0 // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/oauth2 v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.9.0 // indirect
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

48
go.sum
View File

@@ -20,7 +20,11 @@ cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvf
cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute v1.19.3 h1:DcTwsFgGev/wV5+q8o2fzgcHOaac+DKGC91ZlvpsQds=
cloud.google.com/go/compute v1.19.3/go.mod h1:qxvISKp/gYnXkSAD1ppcSOveRAmzxicEv/JlizULFrI=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
@@ -87,8 +91,9 @@ github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@@ -101,6 +106,7 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
@@ -163,6 +169,7 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
@@ -219,6 +226,7 @@ github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -250,12 +258,14 @@ github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff/go.mod h1:wfqRWLHRBs
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
@@ -321,13 +331,19 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc=
github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A=
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k=
github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gax-go/v2 v2.10.0 h1:ebSgKfMxynOdxw8QQuFOKMgomqeLGPqNLQox2bo42zg=
github.com/googleapis/gax-go/v2 v2.10.0/go.mod h1:4UOEnMCrxsSqQ940WnTiD6qJ63le2ev3xfyagutxiPw=
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
github.com/googleapis/gnostic v0.4.0/go.mod h1:on+2t9HRStVgn95RSsFWFz+6Q0Snyqv1awfrALZdbtU= github.com/googleapis/gnostic v0.4.0/go.mod h1:on+2t9HRStVgn95RSsFWFz+6Q0Snyqv1awfrALZdbtU=
github.com/googleapis/gnostic v0.5.1/go.mod h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU= github.com/googleapis/gnostic v0.5.1/go.mod h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU=
@@ -469,6 +485,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw= github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
@@ -661,6 +679,8 @@ github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
@@ -668,6 +688,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4= go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE= go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60= go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
@@ -697,10 +719,11 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -831,8 +854,9 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -949,6 +973,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
@@ -1041,6 +1066,8 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
google.golang.org/api v0.126.0 h1:q4GJq+cAdMAC7XP7njvQ4tvohGLiSlytuL4BQxbIZ+o=
google.golang.org/api v0.126.0/go.mod h1:mBwVAtz+87bEN6CbA1GtZPDOqY2R5ONPqJeIlvyo4Aw=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -1082,8 +1109,10 @@ google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48=
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 h1:a2S6M0+660BgMNl++4JPlcAO/CjkqYItDEZwkoDQK7c= google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao=
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg= google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1102,9 +1131,10 @@ google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTp
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ=
google.golang.org/grpc v1.51.0-dev/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/grpc v1.51.0-dev/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
google.golang.org/grpc v1.52.3 h1:pf7sOysg4LdgBqduXveGKrcEwbStiK2rtfghdzlUYDQ= google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag=
google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5vorUY= google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -5,7 +5,6 @@ import (
"sync" "sync"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
@@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create(mIFaceArgs) return w.tun.Create(mIFaceArgs)
} }

View File

@@ -7,7 +7,6 @@ import (
"sync" "sync"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
@@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create() return w.tun.Create()
} }

View File

@@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
} }
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
log.Info("create tun interface")
var err error var err error
routesString := t.routesToString(mIFaceArgs.Routes) routesString := t.routesToString(mIFaceArgs.Routes)
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString) t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)

View File

@@ -12,14 +12,14 @@ import (
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() { if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard") log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
return c.createWithKernel() return c.createWithKernel()
} }
if !tunModuleIsLoaded() { if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module") return fmt.Errorf("couldn't check or load tun module")
} }
log.Info("using userspace WireGuard") log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
var err error var err error
c.netInterface, err = c.createWithUserspace() c.netInterface, err = c.createWithUserspace()
if err != nil { if err != nil {

View File

@@ -56,7 +56,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
c.wrapper = newDeviceWrapper(tunIface) c.wrapper = newDeviceWrapper(tunIface)
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] ")) tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up() err = tunDev.Up()
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()

View File

@@ -97,16 +97,9 @@ curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o openid-configuration.js
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json) export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json) export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
export NETBIRD_AUTH_SUPPORTED_SCOPES=$(jq -r '.scopes_supported | join(" ")' openid-configuration.json)
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json) export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json) export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
if [ "$NETBIRD_USE_AUTH0" == "true" ]; then
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api email_verified"
else
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api"
fi
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
# user enabled Device Authorization Grant feature # user enabled Device Authorization Grant feature
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted" export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"

View File

@@ -11,7 +11,10 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
NETBIRD_AUTH_AUDIENCE="" NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client # e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID="" NETBIRD_AUTH_CLIENT_ID=""
NETBIRD_AUTH_CLIENT_SECRET="" # indicates the scopes that will be requested to the IDP
NETBIRD_AUTH_SUPPORTED_SCOPES=""
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
# NETBIRD_AUTH_CLIENT_SECRET=""
# if you want to use a custom claim for the user ID instead of 'sub', set it here # if you want to use a custom claim for the user ID instead of 'sub', set it here
# NETBIRD_AUTH_USER_ID_CLAIM="" # NETBIRD_AUTH_USER_ID_CLAIM=""
# indicates whether to use Auth0 or not: true or false # indicates whether to use Auth0 or not: true or false

View File

@@ -6,6 +6,7 @@ NETBIRD_DOMAIN=$CI_NETBIRD_DOMAIN
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration" NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
# e.g. netbird-client # e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
NETBIRD_AUTH_SUPPORTED_SCOPES=$CI_NETBIRD_AUTH_SUPPORTED_SCOPES
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
# indicates whether to use Auth0 or not: true or false # indicates whether to use Auth0 or not: true or false
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0 NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0

View File

@@ -218,7 +218,11 @@ var (
if !disableMetrics { if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager) idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run() go metricsWorker.Run()
} }

View File

@@ -34,6 +34,8 @@ const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
@@ -51,6 +53,7 @@ type AccountManager interface {
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, initiatorUserID string, targetUserID string) error DeleteUser(accountID, initiatorUserID string, targetUserID string) error
InviteUser(accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
@@ -78,7 +81,7 @@ type AccountManager interface {
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, groupID string) error DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error) ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error GroupAddPeer(accountId, groupID, peerID string) error
GroupDeletePeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerKey string) error
@@ -139,6 +142,13 @@ type Settings struct {
// PeerLoginExpiration is a setting that indicates when peer login expires. // PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true. // Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration PeerLoginExpiration time.Duration
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
} }
// Copy copies the Settings struct // Copy copies the Settings struct
@@ -146,6 +156,8 @@ func (s *Settings) Copy() *Settings {
return &Settings{ return &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration, PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
} }
} }
@@ -612,6 +624,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// AddJWTGroups to existed groups if they does not exists
func (a *Account) AddJWTGroups(groups []string) (int, error) {
existedGroups := make(map[string]*Group)
for _, g := range a.Groups {
existedGroups[g.Name] = g
}
var count int
for _, name := range groups {
if _, ok := existedGroups[name]; !ok {
id := xid.New().String()
a.Groups[id] = &Group{
ID: id,
Name: name,
Issued: GroupIssuedJWT,
}
count++
}
}
return count, nil
}
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
@@ -1241,6 +1275,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
} }
} }
if account.Settings.JWTGroupsEnabled {
if account.Settings.JWTGroupsClaimName == "" {
log.Errorf("JWT groups are enabled but no claim name is set")
return account, user, nil
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok {
var groups []string
for _, item := range slice {
if g, ok := item.(string); ok {
groups = append(groups, g)
} else {
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
}
}
n, err := account.AddJWTGroups(groups)
if err != nil {
log.Errorf("failed to add JWT groups: %v", err)
}
if n > 0 {
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
}
}
} else {
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
}
} else {
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
}
}
return account, user, nil return account, user, nil
} }
@@ -1346,6 +1412,7 @@ func addAllGroup(account *Account) error {
allGroup := &Group{ allGroup := &Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: "All", Name: "All",
Issued: GroupIssuedAPI,
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID) allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1373,33 +1440,28 @@ func addAllGroup(account *Account) error {
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId, userId, domain string) *Account { func newAccountWithId(accountID, userID, domain string) *Account {
log.Debugf("creating new account") log.Debugf("creating new account")
setupKeys := make(map[string]*SetupKey)
defaultKey := GenerateDefaultSetupKey()
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage)
setupKeys[defaultKey.Key] = defaultKey
setupKeys[oneOffKey.Key] = oneOffKey
network := NewNetwork() network := NewNetwork()
peers := make(map[string]*Peer) peers := make(map[string]*Peer)
users := make(map[string]*User) users := make(map[string]*User)
routes := make(map[string]*route.Route) routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userId] = NewAdminUser(userId) users[userID] = NewAdminUser(userID)
dnsSettings := &DNSSettings{ dnsSettings := &DNSSettings{
DisabledManagementGroups: make([]string, 0), DisabledManagementGroups: make([]string, 0),
} }
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) log.Debugf("created new account %s", accountID)
acc := &Account{ acc := &Account{
Id: accountId, Id: accountID,
SetupKeys: setupKeys, SetupKeys: setupKeys,
Network: network, Network: network,
Peers: peers, Peers: peers,
Users: users, Users: users,
CreatedBy: userId, CreatedBy: userID,
Domain: domain, Domain: domain,
Routes: routes, Routes: routes,
NameServerGroups: nameServersGroups, NameServerGroups: nameServersGroups,

View File

@@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -53,7 +54,7 @@ func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
} }
if len(account.SetupKeys) != 2 { if len(account.SetupKeys) != 0 {
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys)) t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
} }
@@ -460,6 +461,75 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
} }
} }
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId("", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount.Id = acc.Id
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*Group{}
for _, g := range account.Groups {
groupsByNames[g.Name] = g
}
g1, ok := groupsByNames["group1"]
require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
g2, ok := groupsByNames["group2"]
require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match")
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
})
}
func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId("account_id", "testuser", "")
@@ -704,20 +774,21 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err := createAccount(manager, "test_account", "account_creator", "netbird.cloud") userID := "account_creator"
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
serial := account.Network.CurrentSerial() // should be 0 serial := account.Network.CurrentSerial() // should be 0
var setupKey *SetupKey setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
for _, key := range account.SetupKeys { if err != nil {
setupKey = key return
} }
if setupKey == nil { if err != nil {
t.Errorf("expecting account to have a default setup key") t.Fatal("error creating setup key")
return return
} }
@@ -858,16 +929,13 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var setupKey *SetupKey setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
for _, key := range account.SetupKeys { if err != nil {
setupKey = key return
if setupKey.Type == SetupKeyReusable {
break
}
} }
if setupKey == nil { if err != nil {
t.Errorf("expecting account to have a default setup key") t.Fatal("error creating setup key")
return return
} }
@@ -1021,7 +1089,10 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.DeleteGroup(account.Id, group.ID); err != nil { // clean policy is pre requirement for delete group
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -1042,9 +1113,14 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var setupKey *SetupKey setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
for _, key := range account.SetupKeys { if err != nil {
setupKey = key return
}
if err != nil {
t.Fatal("error creating setup key")
return
} }
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()

View File

@@ -95,6 +95,8 @@ const (
UserBlocked UserBlocked
// UserUnblocked indicates that a user unblocked another user // UserUnblocked indicates that a user unblocked another user
UserUnblocked UserUnblocked
// GroupDeleted indicates that a user deleted group
GroupDeleted
) )
const ( const (
@@ -192,6 +194,8 @@ const (
UserBlockedMessage string = "User blocked" UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity // UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked" UserUnblockedMessage string = "User unblocked"
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
GroupDeletedMessage string = "Group deleted"
) )
// Activity that triggered an Event // Activity that triggered an Event
@@ -294,6 +298,8 @@ func (a Activity) Message() string {
return UserBlockedMessage return UserBlockedMessage
case UserUnblocked: case UserUnblocked:
return UserUnblockedMessage return UserUnblockedMessage
case GroupDeleted:
return GroupDeletedMessage
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
@@ -342,6 +348,8 @@ func (a Activity) StringCode() string {
return "group.add" return "group.add"
case GroupUpdated: case GroupUpdated:
return "group.update" return "group.update"
case GroupDeleted:
return "group.delete"
case GroupRemovedFromPeer: case GroupRemovedFromPeer:
return "peer.group.delete" return "peer.group.delete"
case GroupAddedToPeer: case GroupAddedToPeer:

View File

@@ -2,13 +2,15 @@ package server
import ( import (
"fmt" "fmt"
"strconv"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"strconv"
) )
const defaultTTL = 300 const defaultTTL = 300
@@ -199,15 +201,27 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
for _, gID := range nsGroup.Groups { for _, gID := range nsGroup.Groups {
_, found := groupList[gID] _, found := groupList[gID]
if found { if found {
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy()) peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break break
} }
} }
} }
}
return peerNSGroups return peerNSGroups
} }
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) {
return true
}
}
return false
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) { func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers { for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels) label, err := getPeerHostLabel(peer.Name, peerLabels)

View File

@@ -1,10 +1,12 @@
package server package server
import ( import (
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -17,6 +19,7 @@ const (
dnsAccountID = "testingAcc" dnsAccountID = "testingAcc"
dnsAdminUserID = "testingAdminUser" dnsAdminUserID = "testingAdminUser"
dnsRegularUserID = "testingRegularUser" dnsRegularUserID = "testingRegularUser"
dnsNSGroup1 = "ns1"
) )
func TestGetDNSSettings(t *testing.T) { func TestGetDNSSettings(t *testing.T) {
@@ -163,6 +166,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy() dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
@@ -174,11 +178,11 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID) peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
} }
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
@@ -246,7 +250,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return nil, err return nil, err
} }
_, _, err = am.AddPeer("", dnsAdminUserID, peer1) savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -284,6 +288,24 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
account.Groups[newGroup1.ID] = newGroup1 account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2 account.Groups[newGroup2.ID] = newGroup2
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, err
}
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
ID: dnsNSGroup1,
Name: "ns-group-1",
NameServers: []dns.NameServer{{
IP: netip.MustParseAddr(savedPeer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
Primary: true,
Enabled: true,
Groups: []string{allGroup.ID},
}
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
addPeerLabelsToAccount(account, existingLabels) addPeerLabelsToAccount(account, existingLabels)
} }
// TODO: delete this block after migration
// Set API as issuer for groups which has not this field
for _, group := range account.Groups {
if group.Issued == "" {
group.Issued = GroupIssuedAPI
}
}
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
@@ -281,6 +289,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
if account.Id == "" {
return status.Errorf(status.InvalidArgument, "account id should not be empty")
}
accountCopy := account.Copy() accountCopy := account.Copy()
s.Accounts[accountCopy.Id] = accountCopy s.Accounts[accountCopy.Id] = accountCopy
@@ -326,7 +338,7 @@ func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
delete(s.HashedPAT2TokenID, hashedToken) delete(s.HashedPAT2TokenID, hashedToken)
return s.persist(s.storeFile) return nil
} }
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID // DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
@@ -336,7 +348,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
delete(s.TokenID2UserID, tokenID) delete(s.TokenID2UserID, tokenID)
return s.persist(s.storeFile) return nil
} }
// GetAccountByPrivateDomain returns account by private domain // GetAccountByPrivateDomain returns account by private domain

View File

@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
} }
// TODO: outdated, delete this
func TestRestorePolicies_Migration(t *testing.T) { func TestRestorePolicies_Migration(t *testing.T) {
storeDir := t.TempDir() storeDir := t.TempDir()
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
"failed to restore a FileStore file - missing Account Policies Sources") "failed to restore a FileStore file - missing Account Policies Sources")
} }
func TestRestoreGroups_Migration(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
// create default group
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
account.Groups = map[string]*Group{
"cfefqs706sqkneg59g3g": {
ID: "cfefqs706sqkneg59g3g",
Name: "All",
},
}
err = store.SaveAccount(account)
require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field
if store, err = NewFileStore(storeDir, nil); err != nil {
return
}
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
}
func TestGetAccountByPrivateDomain(t *testing.T) { func TestGetAccountByPrivateDomain(t *testing.T) {
storeDir := t.TempDir() storeDir := t.TempDir()

View File

@@ -1,11 +1,23 @@
package server package server
import ( import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
type GroupLinkError struct {
Resource string
Name string
}
func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
// Group of the peers for ACL // Group of the peers for ACL
type Group struct { type Group struct {
// ID of the group // ID of the group
@@ -14,6 +26,9 @@ type Group struct {
// Name visible in the UI // Name visible in the UI
Name string Name string
// Issued of the group
Issued string
// Peers list of the group // Peers list of the group
Peers []string Peers []string
} }
@@ -47,6 +62,7 @@ func (g *Group) Copy() *Group {
return &Group{ return &Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Issued: g.Issued,
Peers: g.Peers[:], Peers: g.Peers[:],
} }
} }
@@ -199,15 +215,80 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
} }
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountId)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountId)
if err != nil { if err != nil {
return err return err
} }
g, ok := account.Groups[groupID]
if !ok {
return nil
}
// check route links
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", r.NetID}
}
}
}
// check DNS links
for _, dns := range account.NameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return &GroupLinkError{"name server groups", dns.Name}
}
}
}
// check ACL links
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
for _, src := range rule.Sources {
if src == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
for _, dst := range rule.Destinations {
if dst == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
}
}
// check setup key links
for _, setupKey := range account.SetupKeys {
for _, grp := range setupKey.AutoGroups {
if grp == groupID {
return &GroupLinkError{"setup key", setupKey.Name}
}
}
}
// check user links
for _, user := range account.Users {
for _, grp := range user.AutoGroups {
if grp == groupID {
return &GroupLinkError{"user", user.Id}
}
}
}
// check DisabledManagementGroups
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
if disabledMgmGrp == groupID {
return &GroupLinkError{"disabled DNS management groups", g.Name}
}
}
delete(account.Groups, groupID) delete(account.Groups, groupID)
account.Network.IncSerial() account.Network.IncSerial()
@@ -215,6 +296,8 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
return err return err
} }
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
return am.updateAccountPeers(account) return am.updateAccountPeers(account)
} }

View File

@@ -0,0 +1,164 @@
package server
import (
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
)
const (
groupAdminUserID = "testingAdminUser"
)
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
testCases := []struct {
name string
groupID string
expectedReason string
}{
{
"route",
"grp-for-route",
"route",
},
{
"name server groups",
"grp-for-name-server-grp",
"name server groups",
},
{
"policy",
"grp-for-policies",
"policy",
},
{
"setup keys",
"grp-for-keys",
"setup key",
},
{
"users",
"grp-for-users",
"user",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
err = am.DeleteGroup(account.Id, "", testCase.groupID)
if err == nil {
t.Errorf("delete %s group successfully", testCase.groupID)
return
}
gErr, ok := err.(*GroupLinkError)
if !ok {
t.Error("invalid error type")
return
}
if gErr.Resource != testCase.expectedReason {
t.Errorf("invalid error case: %s, expected: %s", gErr.Resource, testCase.expectedReason)
}
})
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
accountID := "testingAcc"
domain := "example.com"
groupForRoute := &Group{
"grp-for-route",
"Group for route",
GroupIssuedAPI,
make([]string, 0),
}
groupForNameServerGroups := &Group{
"grp-for-name-server-grp",
"Group for name server groups",
GroupIssuedAPI,
make([]string, 0),
}
groupForPolicies := &Group{
"grp-for-policies",
"Group for policies",
GroupIssuedAPI,
make([]string, 0),
}
groupForSetupKeys := &Group{
"grp-for-keys",
"Group for setup keys",
GroupIssuedAPI,
make([]string, 0),
}
groupForUsers := &Group{
"grp-for-users",
"Group for users",
GroupIssuedAPI,
make([]string, 0),
}
routeResource := &route.Route{
ID: "example route",
Groups: []string{groupForRoute.ID},
}
nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID},
}
policy := &Policy{
ID: "example policy",
Rules: []*PolicyRule{
{
ID: "example policy rule",
Destinations: []string{groupForPolicies.ID},
},
},
}
setupKey := &SetupKey{
Id: "example setup key",
AutoGroups: []string{groupForSetupKeys.ID},
}
user := &User{
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
return am.Store.GetAccount(account.Id)
}

View File

@@ -72,10 +72,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
return return
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, &server.Settings{ settings := &server.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
}) }
if req.Settings.JwtGroupsEnabled != nil {
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
}
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@@ -93,6 +102,8 @@ func toAccountResponse(account *server.Account) *api.Account {
Settings: api.AccountSettings{ Settings: api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
}, },
} }
} }

View File

@@ -58,6 +58,9 @@ func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }
handler := initAccountsTestData(&server.Account{ handler := initAccountsTestData(&server.Account{
Id: accountID, Id: accountID,
Domain: "hotmail.com", Domain: "hotmail.com",
@@ -91,6 +94,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: int(time.Hour.Seconds()), PeerLoginExpiration: int(time.Hour.Seconds()),
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
}, },
expectedArray: true, expectedArray: true,
expectedID: accountID, expectedID: accountID,
@@ -105,6 +110,24 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000, PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
},
{
name: "PutAccount OK wiht JWT",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: false,
JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true),
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,

View File

@@ -54,6 +54,14 @@ components:
description: Period of time after which peer login expires (seconds). description: Period of time after which peer login expires (seconds).
type: integer type: integer
example: 43200 example: 43200
jwt_groups_enabled:
description: Allows extract groups from JWT claim and add it to account groups.
type: boolean
example: true
jwt_groups_claim_name:
description: Name of the claim from which we extract groups names to add it to account groups.
type: string
example: "roles"
required: required:
- peer_login_expiration_enabled - peer_login_expiration_enabled
- peer_login_expiration - peer_login_expiration
@@ -462,6 +470,10 @@ components:
description: Count of peers associated to the group description: Count of peers associated to the group
type: integer type: integer
example: 2 example: 2
issued:
description: How group was issued by API or from JWT token
type: string
example: api
required: required:
- id - id
- name - name
@@ -1262,6 +1274,33 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/users/{userId}/invite:
post:
summary: Resend user invitation
description: Resend user invitation
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The unique identifier of a user
responses:
'200':
description: Invite status code
content: {}
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers: /api/peers:
get: get:
summary: List all Peers summary: List all Peers

View File

@@ -129,6 +129,12 @@ type AccountRequest struct {
// AccountSettings defines model for AccountSettings. // AccountSettings defines model for AccountSettings.
type AccountSettings struct { type AccountSettings struct {
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
// PeerLoginExpiration Period of time after which peer login expires (seconds). // PeerLoginExpiration Period of time after which peer login expires (seconds).
PeerLoginExpiration int `json:"peer_login_expiration"` PeerLoginExpiration int `json:"peer_login_expiration"`
@@ -174,6 +180,9 @@ type Group struct {
// Id Group ID // Id Group ID
Id string `json:"id"` Id string `json:"id"`
// Issued How group was issued by API or from JWT token
Issued *string `json:"issued,omitempty"`
// Name Group Name identifier // Name Group Name identifier
Name string `json:"name"` Name string `json:"name"`
@@ -189,6 +198,9 @@ type GroupMinimum struct {
// Id Group ID // Id Group ID
Id string `json:"id"` Id string `json:"id"`
// Issued How group was issued by API or from JWT token
Issued *string `json:"issued,omitempty"`
// Name Group Name identifier // Name Group Name identifier
Name string `json:"name"` Name string `json:"name"`

View File

@@ -72,7 +72,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
_, ok = account.Groups[groupID] eg, ok := account.Groups[groupID]
if !ok { if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return return
@@ -110,6 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID, ID: groupID,
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
Issued: eg.Issued,
} }
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
@@ -152,6 +153,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
ID: xid.New().String(), ID: xid.New().String(),
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
Issued: server.GroupIssuedAPI,
} }
err = h.accountManager.SaveGroup(account.Id, user.Id, &group) err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
@@ -166,7 +168,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
// DeleteGroup handles group deletion request // DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@@ -190,8 +192,13 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteGroup(aID, groupID) err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
if err != nil { if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
util.WriteError(err, w) util.WriteError(err, w)
return return
} }
@@ -237,6 +244,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
Issued: &group.Issued,
} }
for _, pid := range group.Peers { for _, pid := range group.Peers {

View File

@@ -11,17 +11,15 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
) )
var TestPeers = map[string]*server.Peer{ var TestPeers = map[string]*server.Peer{
@@ -42,9 +40,17 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
if groupID != "idofthegroup" { if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found") return nil, status.Errorf(status.NotFound, "not found")
} }
if groupID == "id-jwt-group" {
return &server.Group{
ID: "id-jwt-group",
Name: "Default Group",
Issued: server.GroupIssuedJWT,
}, nil
}
return &server.Group{ return &server.Group{
ID: "idofthegroup", ID: "idofthegroup",
Name: "Group", Name: "Group",
Issued: server.GroupIssuedAPI,
}, nil }, nil
}, },
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
@@ -80,11 +86,24 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
user.Id: user, user.Id: user,
}, },
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT},
"id-all": {ID: "id-all", Name: "All"}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI},
}, },
}, user, nil }, user, nil
}, },
DeleteGroupFunc: func(accountID, userId, groupID string) error {
if groupID == "linked-grp" {
return &server.GroupLinkError{
Resource: "something",
Name: "linked-grp",
}
}
if groupID == "invalid-grp" {
return fmt.Errorf("internal error")
}
return nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -169,6 +188,8 @@ func TestGetGroup(t *testing.T) {
} }
func TestWriteGroup(t *testing.T) { func TestWriteGroup(t *testing.T) {
groupIssuedAPI := "api"
groupIssuedJWT := "jwt"
tt := []struct { tt := []struct {
name string name string
expectedStatus int expectedStatus int
@@ -189,6 +210,7 @@ func TestWriteGroup(t *testing.T) {
expectedGroup: &api.Group{ expectedGroup: &api.Group{
Id: "id-was-set", Id: "id-was-set",
Name: "Default POSTed Group", Name: "Default POSTed Group",
Issued: &groupIssuedAPI,
}, },
}, },
{ {
@@ -210,6 +232,7 @@ func TestWriteGroup(t *testing.T) {
expectedGroup: &api.Group{ expectedGroup: &api.Group{
Id: "id-existed", Id: "id-existed",
Name: "Default POSTed Group", Name: "Default POSTed Group",
Issued: &groupIssuedAPI,
}, },
}, },
{ {
@@ -230,6 +253,19 @@ func TestWriteGroup(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{
name: "Write Group PUT not not change Issue",
requestType: http.MethodPut,
requestPath: "/api/groups/id-jwt-group",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"changed","Issued":"api"}`)),
expectedStatus: http.StatusOK,
expectedGroup: &api.Group{
Id: "id-jwt-group",
Name: "changed",
Issued: &groupIssuedJWT,
},
},
} }
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user")
@@ -271,3 +307,79 @@ func TestWriteGroup(t *testing.T) {
}) })
} }
} }
func TestDeleteGroup(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
}{
{
name: "Try to delete linked group",
requestType: http.MethodDelete,
requestPath: "/api/groups/linked-grp",
expectedStatus: http.StatusBadRequest,
expectedBody: true,
},
{
name: "Try to cause internal error",
requestType: http.MethodDelete,
requestPath: "/api/groups/invalid-grp",
expectedStatus: http.StatusInternalServerError,
expectedBody: true,
},
{
name: "Try to cause internal error",
requestType: http.MethodDelete,
requestPath: "/api/groups/invalid-grp",
expectedStatus: http.StatusInternalServerError,
expectedBody: true,
},
{
name: "Delete group",
requestType: http.MethodDelete,
requestPath: "/api/groups/any-grp",
expectedStatus: http.StatusOK,
expectedBody: false,
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if tc.expectedBody {
got := &util.ErrorResponse{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.Code, tc.expectedStatus)
}
})
}
}

View File

@@ -113,6 +113,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
} }
func (apiHandler *apiHandler) addUsersTokensEndpoint() { func (apiHandler *apiHandler) addUsersTokensEndpoint() {

View File

@@ -57,7 +57,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer": case "bearer":
err := m.CheckJWTFromRequest(w, r) err := m.CheckJWTFromRequest(w, r)
if err != nil { if err != nil {
log.Debugf("Error when validating JWT claims: %s", err.Error()) log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return return
} }

View File

@@ -208,6 +208,37 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, users) util.WriteJSONObject(w, users)
} }
// InviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups autoGroups := user.AutoGroups
if autoGroups == nil { if autoGroups == nil {

View File

@@ -98,6 +98,17 @@ func initUsersTestData() *UsersHandler {
} }
return info, nil return info, nil
}, },
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if initiatorUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
}
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
return nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -340,6 +351,51 @@ func TestCreateUser(t *testing.T) {
} }
} }
func TestInviteUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestVars map[string]string
}{
{
name: "Invite User with Existing User",
requestType: http.MethodPost,
requestPath: "/api/users/" + existingUserID + "/invite",
expectedStatus: http.StatusOK,
requestVars: map[string]string{"userId": existingUserID},
},
{
name: "Invite User with missing user_id",
requestType: http.MethodPost,
requestPath: "/api/users/" + notFoundUserID + "/invite",
expectedStatus: http.StatusNotFound,
requestVars: map[string]string{"userId": notFoundUserID},
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder()
userHandler.InviteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
}
})
}
}
func TestDeleteUser(t *testing.T) { func TestDeleteUser(t *testing.T) {
tt := []struct { tt := []struct {
name string name string

View File

@@ -13,6 +13,11 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
type ErrorResponse struct {
Message string `json:"message"`
Code int `json:"code"`
}
// WriteJSONObject simply writes object to the HTTP reponse in JSON format // WriteJSONObject simply writes object to the HTTP reponse in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) { func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -58,14 +63,9 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
// WriteErrorResponse prepares and writes an error response i nJSON // WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
type errorResponse struct {
Message string `json:"message"`
Code int `json:"code"`
}
w.WriteHeader(httpStatus) w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(&errorResponse{ err := json.NewEncoder(w).Encode(&ErrorResponse{
Message: errMsg, Message: errMsg,
Code: httpStatus, Code: httpStatus,
}) })

View File

@@ -98,6 +98,11 @@ type userExportJobStatusResponse struct {
ID string `json:"id"` ID string `json:"id"`
} }
// userVerificationJobRequest is a user verification request struct
type userVerificationJobRequest struct {
UserID string `json:"user_id"`
}
// auth0Profile represents an Auth0 user profile response // auth0Profile represents an Auth0 user profile response
type auth0Profile struct { type auth0Profile struct {
AccountID string `json:"wt_account_id"` AccountID string `json:"wt_account_id"`
@@ -689,6 +694,48 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
return &createResp, nil return &createResp, nil
} }
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *Auth0Manager) InviteUserByID(userID string) error {
userVerificationReq := userVerificationJobRequest{
UserID: userID,
}
payload, err := am.helper.Marshal(userVerificationReq)
if err != nil {
return err
}
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("error while closing invite user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
}
return nil
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink // If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {

View File

@@ -440,6 +440,12 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
return users, nil return users, nil
} }
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AuthentikManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
func (am *AuthentikManager) authenticationContext() (context.Context, error) { func (am *AuthentikManager) authenticationContext() (context.Context, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {

View File

@@ -448,6 +448,12 @@ func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
return nil return nil
} }
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) { func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{} q := url.Values{}
q.Add("$select", extensionFields) q.Add("$select", extensionFields)

View File

@@ -0,0 +1,350 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"net/http"
"strings"
"time"
)
// GoogleWorkspaceManager Google Workspace manager client instance.
type GoogleWorkspaceManager struct {
usersService *admin.UsersService
CustomerID string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
type GoogleWorkspaceClientConfig struct {
ServiceAccountKey string
CustomerID string
}
// GoogleWorkspaceCredentials Google Workspace authentication information.
type GoogleWorkspaceCredentials struct {
clientConfig GoogleWorkspaceClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
return JWTToken{}, nil
}
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.CustomerID == "" {
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
}
credentials := &GoogleWorkspaceCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
// Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
if err != nil {
return nil, err
}
service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope),
option.WithCredentials(adminCredentials),
)
if err != nil {
return nil, err
}
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
return nil, err
}
return &GoogleWorkspaceManager{
usersService: service.Users,
CustomerID: config.CustomerID,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
metadata, err := gm.helper.Marshal(appMetadata)
if err != nil {
return err
}
user := &admin.User{
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": metadata,
},
}
_, err = gm.usersService.Update(userID, user).Do()
if err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil
}
// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Projection("full").Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
return parseGoogleWorkspaceUser(user)
}
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID)
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
if err != nil {
return nil, err
}
usersData := make([]*UserData, 0)
for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
usersData = append(usersData, userData)
}
return usersData, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
}
return indexedUsers, nil
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
invite := true
metadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
username := &admin.UserName{}
fields := strings.Fields(name)
if n := len(fields); n > 0 {
username.GivenName = strings.Join(fields[:n-1], " ")
username.FamilyName = fields[n-1]
}
payload, err := gm.helper.Marshal(metadata)
if err != nil {
return nil, err
}
user := &admin.User{
Name: username,
PrimaryEmail: email,
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": payload,
},
Password: GeneratePassword(8, 1, 1, 1),
}
user, err = gm.usersService.Insert(user).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountCreateUser()
}
return parseGoogleWorkspaceUser(user)
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Projection("full").Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, userData)
return users, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
log.Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err)
}
creds, err := google.CredentialsFromJSON(
context.Background(),
decodeKey,
admin.AdminDirectoryUserschemaScope,
admin.AdminDirectoryUserScope,
)
if err == nil {
return creds, err
}
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials(
context.Background(),
admin.AdminDirectoryUserschemaScope,
admin.AdminDirectoryUserScope,
)
if err != nil {
return nil, err
}
return creds, nil
}
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
schemaList, err := service.Schemas.List(customerID).Do()
if err != nil {
return err
}
// checks if app_metadata schema is already created
for _, schema := range schemaList.Schemas {
if schema.SchemaName == "app_metadata" {
return nil
}
}
// create new app_metadata schema
appMetadataSchema := &admin.Schema{
SchemaName: "app_metadata",
Fields: []*admin.SchemaFieldSpec{
{
FieldName: "wt_account_id",
FieldType: "STRING",
MultiValued: false,
},
{
FieldName: "wt_pending_invite",
FieldType: "BOOL",
MultiValued: false,
},
},
}
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
if err != nil {
return err
}
return nil
}
// parseGoogleWorkspaceUser parse google user to UserData.
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) {
var appMetadata AppMetadata
// Get app metadata from custom schemas
if user.CustomSchemas != nil {
rawMessage := user.CustomSchemas["app_metadata"]
helper := JsonParser{}
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
return nil, err
}
}
return &UserData{
ID: user.Id,
Email: user.PrimaryEmail,
Name: user.Name.FullName,
AppMetadata: appMetadata,
}, nil
}

View File

@@ -17,6 +17,7 @@ type Manager interface {
GetAllAccounts() (map[string][]*UserData, error) GetAllAccounts() (map[string][]*UserData, error)
CreateUser(email string, name string, accountID string) (*UserData, error) CreateUser(email string, name string, accountID string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error) GetUserByEmail(email string) ([]*UserData, error)
InviteUserByID(userID string) error
} }
// ClientConfig defines common client configuration for all IdP manager // ClientConfig defines common client configuration for all IdP manager
@@ -162,6 +163,12 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
APIToken: config.ExtraConfig["ApiToken"], APIToken: config.ExtraConfig["ApiToken"],
} }
return NewOktaManager(oktaClientConfig, appMetrics) return NewOktaManager(oktaClientConfig, appMetrics)
case "google":
googleClientConfig := GoogleWorkspaceClientConfig{
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"],
}
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
default: default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)

View File

@@ -461,6 +461,12 @@ func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppM
return nil return nil
} }
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
attrs := keycloakUserAttributes{} attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID) attrs.Set(wtAccountID, appMetadata.WTAccountID)

Some files were not shown because too many files have changed in this diff Show More