mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 07:06:41 +00:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6dee89379b | ||
|
|
76db4f801a | ||
|
|
6c2ed4b4f2 | ||
|
|
2541c78dd0 | ||
|
|
97b6e79809 | ||
|
|
6ad3847615 | ||
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d | ||
|
|
774d8e955c | ||
|
|
c20f98c8b6 | ||
|
|
20ae540fb1 | ||
|
|
58cfa2bb17 | ||
|
|
06005cc10e | ||
|
|
1a3e377304 | ||
|
|
dd29f4c01e | ||
|
|
cb7ecd1cc4 | ||
|
|
b5d8142705 | ||
|
|
f45eb1a1da | ||
|
|
2567006412 | ||
|
|
b92107efc8 | ||
|
|
5d19811331 | ||
|
|
697d41c94e | ||
|
|
75d541f967 | ||
|
|
7dfbb71f7a | ||
|
|
a5d14c92ff | ||
|
|
ce091ab42b | ||
|
|
d2fad1cfd9 | ||
|
|
0b5594f145 | ||
|
|
9beaa91db9 | ||
|
|
c8b4c08139 | ||
|
|
dad5501a44 | ||
|
|
1ced2462c1 | ||
|
|
64adaeb276 | ||
|
|
6e26d03fb8 | ||
|
|
493ddb4fe3 | ||
|
|
75fac258e7 | ||
|
|
bc8ee8fc3c | ||
|
|
3724323f76 | ||
|
|
3ef33874b1 |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -116,7 +116,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
- name: Install rsrc
|
- name: Install rsrc
|
||||||
run: go install github.com/akavel/rsrc@v0.10.2
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
- name: Generate windows rsrc
|
- name: Generate windows rsrc
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ jobs:
|
|||||||
CI_NETBIRD_MGMT_IDP: "none"
|
CI_NETBIRD_MGMT_IDP: "none"
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||||
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -7,8 +7,15 @@ bin/
|
|||||||
conf.json
|
conf.json
|
||||||
http-cmds.sh
|
http-cmds.sh
|
||||||
infrastructure_files/management.json
|
infrastructure_files/management.json
|
||||||
|
infrastructure_files/management-*.json
|
||||||
infrastructure_files/docker-compose.yml
|
infrastructure_files/docker-compose.yml
|
||||||
|
infrastructure_files/openid-configuration.json
|
||||||
|
infrastructure_files/turnserver.conf
|
||||||
|
management/management
|
||||||
|
client/client
|
||||||
|
client/client.exe
|
||||||
*.syso
|
*.syso
|
||||||
client/.distfiles/
|
client/.distfiles/
|
||||||
infrastructure_files/setup.env
|
infrastructure_files/setup.env
|
||||||
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
tags:
|
||||||
|
- legacy_appindicator
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -55,9 +57,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
@@ -75,9 +74,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
|
|||||||
@@ -70,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
|
|
||||||
### Start using NetBird
|
### Start using NetBird
|
||||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
|
||||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
|
||||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
|
||||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|
||||||
### Roadmap
|
### Roadmap
|
||||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -35,6 +36,11 @@ type RouteListener interface {
|
|||||||
routemanager.RouteListener
|
routemanager.RouteListener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||||
|
type DnsReadyListener interface {
|
||||||
|
dns.ReadyListener
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -49,6 +55,7 @@ type Client struct {
|
|||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
deviceName string
|
deviceName string
|
||||||
routeListener routemanager.RouteListener
|
routeListener routemanager.RouteListener
|
||||||
|
onHostDnsFn func([]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener) error {
|
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
c.onHostDnsFn = func([]string) {}
|
||||||
|
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetConnectionListener set the network connection listener
|
// SetConnectionListener set the network connection listener
|
||||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
c.recorder.SetConnectionListener(listener)
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
|||||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// DNSList is a wrapper of []string
|
||||||
|
type DNSList struct {
|
||||||
|
items []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new DNS address to the collection
|
||||||
|
func (array *DNSList) Add(s string) {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array *DNSList) Get(i int) (string, error) {
|
||||||
|
if i >= len(array.items) || i < 0 {
|
||||||
|
return "", fmt.Errorf("out of range")
|
||||||
|
}
|
||||||
|
return array.items[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array *DNSList) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDNSList_Get(t *testing.T) {
|
||||||
|
l := DNSList{
|
||||||
|
items: make([]string, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := l.Get(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(-1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Print(err)
|
log.Debug(err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. " +
|
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
"You can verify the connection by running:\n\n" +
|
||||||
"Run the status command: \n\n" +
|
" netbird status\n\n")
|
||||||
" netbird status\n\n" +
|
return err
|
||||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type Manager interface {
|
|||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
direction RuleDirection,
|
||||||
action Action,
|
action Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -60,5 +61,8 @@ type Manager interface {
|
|||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
|
// Flush the changes to firewall controller
|
||||||
|
Flush() error
|
||||||
|
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -35,6 +36,8 @@ type Manager struct {
|
|||||||
inputDefaultRuleSpecs []string
|
inputDefaultRuleSpecs []string
|
||||||
outputDefaultRuleSpecs []string
|
outputDefaultRuleSpecs []string
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
|
rulesets map[string]ruleset
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -43,6 +46,11 @@ type iFaceMapper interface {
|
|||||||
Address() iface.WGAddress
|
Address() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ruleset struct {
|
||||||
|
rule *Rule
|
||||||
|
ips map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||||
outputDefaultRuleSpecs: []string{
|
outputDefaultRuleSpecs: []string{
|
||||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||||
|
rulesets: make(map[string]ruleset),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Init(); err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
// init clients for booth ipv4 and ipv6
|
||||||
@@ -58,14 +71,18 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
|
if isIptablesClientAvailable(ipv4Client) {
|
||||||
m.ipv4Client = ipv4Client
|
m.ipv4Client = ipv4Client
|
||||||
|
}
|
||||||
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||||
} else {
|
} else {
|
||||||
|
if isIptablesClientAvailable(ipv6Client) {
|
||||||
m.ipv6Client = ipv6Client
|
m.ipv6Client = ipv6Client
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
if err := m.Reset(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
||||||
@@ -73,6 +90,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment is empty rule ID is used as comment
|
// If comment is empty rule ID is used as comment
|
||||||
@@ -83,6 +105,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -101,22 +124,45 @@ func (m *Manager) AddFiltering(
|
|||||||
if sPort != nil && sPort.Values != nil {
|
if sPort != nil && sPort.Values != nil {
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||||
}
|
}
|
||||||
|
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
if comment == "" {
|
if comment == "" {
|
||||||
comment = ruleID
|
comment = ruleID
|
||||||
}
|
}
|
||||||
|
|
||||||
specs := m.filterRuleSpecs(
|
if ipsetName != "" {
|
||||||
"filter",
|
rs, rsExists := m.rulesets[ipsetName]
|
||||||
ip,
|
if !rsExists {
|
||||||
string(protocol),
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
sPortVal,
|
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||||
dPortVal,
|
}
|
||||||
direction,
|
if err := ipset.Create(ipsetName); err != nil {
|
||||||
action,
|
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||||
comment,
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsExists {
|
||||||
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
|
rs.ips[ip.String()] = ruleID
|
||||||
|
return &Rule{
|
||||||
|
ruleID: ruleID,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// this is new ipset so we need to create firewall rule for it
|
||||||
|
}
|
||||||
|
|
||||||
|
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
||||||
|
direction, action, comment, ipsetName)
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||||
@@ -144,12 +190,24 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Rule{
|
rule := &Rule{
|
||||||
id: ruleID,
|
ruleID: ruleID,
|
||||||
specs: specs,
|
specs: specs,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
v6: ip.To4() == nil,
|
v6: ip.To4() == nil,
|
||||||
}, nil
|
}
|
||||||
|
if ipsetName != "" {
|
||||||
|
// ipset name is defined and it means that this rule was created
|
||||||
|
// for it, need to assosiate it with ruleset
|
||||||
|
m.rulesets[ipsetName] = ruleset{
|
||||||
|
rule: rule,
|
||||||
|
ips: map[string]string{rule.ip: ruleID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
@@ -170,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
client = m.ipv6Client
|
client = m.ipv6Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||||
|
// delete IP from ruleset IPs list and ipset
|
||||||
|
if _, ok := rs.ips[r.ip]; ok {
|
||||||
|
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||||
|
}
|
||||||
|
delete(rs.ips, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(rs.ips) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and assosiated firewall rule too
|
||||||
|
delete(m.rulesets, r.ipsetName)
|
||||||
|
|
||||||
|
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||||
|
log.Errorf("delete empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
r = rs.rule
|
||||||
|
}
|
||||||
|
|
||||||
if r.dst {
|
if r.dst {
|
||||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||||
}
|
}
|
||||||
@@ -193,6 +276,9 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||||
@@ -233,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for ipsetName := range m.rulesets {
|
||||||
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
if err := ipset.Destroy(ipsetName); err != nil {
|
||||||
|
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
delete(m.rulesets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
func (m *Manager) filterRuleSpecs(
|
func (m *Manager) filterRuleSpecs(
|
||||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||||
direction fw.RuleDirection, action fw.Action, comment string,
|
direction fw.RuleDirection, action fw.Action, comment string,
|
||||||
|
ipsetName string,
|
||||||
) (specs []string) {
|
) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
@@ -249,13 +346,21 @@ func (m *Manager) filterRuleSpecs(
|
|||||||
switch direction {
|
switch direction {
|
||||||
case fw.RuleDirectionIN:
|
case fw.RuleDirectionIN:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
|
} else {
|
||||||
specs = append(specs, "-s", ip.String())
|
specs = append(specs, "-s", ip.String())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
case fw.RuleDirectionOUT:
|
case fw.RuleDirectionOUT:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||||
|
} else {
|
||||||
specs = append(specs, "-d", ip.String())
|
specs = append(specs, "-d", ip.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
specs = append(specs, "-p", protocol)
|
specs = append(specs, "-p", protocol)
|
||||||
}
|
}
|
||||||
@@ -335,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
|||||||
}
|
}
|
||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||||
|
if ipsetName == "" {
|
||||||
|
return ""
|
||||||
|
} else if sPort != "" && dPort != "" {
|
||||||
|
return ipsetName + "-sport-dport"
|
||||||
|
} else if sPort != "" {
|
||||||
|
return ipsetName + "-sport"
|
||||||
|
} else if dPort != "" {
|
||||||
|
return ipsetName + "-dport"
|
||||||
|
}
|
||||||
|
return ipsetName
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -68,7 +69,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
@@ -81,33 +82,31 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule1); err != nil {
|
err := manager.DeleteRule(rule1)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule2); err != nil {
|
err := manager.DeleteRule(rule2)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mock := &iFaceMock{
|
||||||
|
NameFunc: func() string {
|
||||||
|
return "lo"
|
||||||
|
},
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// just check on the local interface
|
||||||
|
manager, err := Create(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := manager.Reset()
|
||||||
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var rule1 fw.Rule
|
||||||
|
t.Run("add first rule with set", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.2")
|
||||||
|
port := &fw.Port{Values: []int{8080}}
|
||||||
|
rule1, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||||
|
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
|
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
var rule2 fw.Rule
|
||||||
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.3")
|
||||||
|
port := &fw.Port{
|
||||||
|
Values: []int{443},
|
||||||
|
}
|
||||||
|
rule2, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||||
|
"default", "accept HTTPS traffic from ports range",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule1)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule2)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reset check", func(t *testing.T) {
|
||||||
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "failed to reset")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||||
require.NoError(t, err, "failed to check rule")
|
require.NoError(t, err, "failed to check rule")
|
||||||
@@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -167,9 +248,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package iptables
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
id string
|
ruleID string
|
||||||
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
|
ip string
|
||||||
dst bool
|
dst bool
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/google/uuid"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -29,11 +31,14 @@ const (
|
|||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
FilterOutputChainName = "netbird-acl-output-filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
|
sConn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
|
|
||||||
@@ -43,6 +48,10 @@ type Manager struct {
|
|||||||
filterInputChainIPv6 *nftables.Chain
|
filterInputChainIPv6 *nftables.Chain
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
filterOutputChainIPv6 *nftables.Chain
|
||||||
|
|
||||||
|
rulesetManager *rulesetManager
|
||||||
|
setRemovedIPs map[string]struct{}
|
||||||
|
setRemoved map[string]*nftables.Set
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,8 +63,23 @@ type iFaceMapper interface {
|
|||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
|
// and is permanent. Using same connection for booth type of operations
|
||||||
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
conn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
|
sConn: sConn,
|
||||||
|
|
||||||
|
rulesetManager: newRuleManager(),
|
||||||
|
setRemovedIPs: map[string]struct{}{},
|
||||||
|
setRemoved: map[string]*nftables.Set{},
|
||||||
|
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
ipset *nftables.Set
|
||||||
table *nftables.Table
|
table *nftables.Table
|
||||||
chain *nftables.Chain
|
chain *nftables.Chain
|
||||||
)
|
)
|
||||||
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawIP := ip.To4()
|
||||||
|
if rawIP == nil {
|
||||||
|
rawIP = ip.To16()
|
||||||
|
}
|
||||||
|
|
||||||
|
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||||
|
|
||||||
|
if ipsetName != "" {
|
||||||
|
// if we already have set with given name, just add ip to the set
|
||||||
|
// and return rule with new ID in other case let's create rule
|
||||||
|
// with fresh created set and set element
|
||||||
|
|
||||||
|
var isSetNew bool
|
||||||
|
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
|
}
|
||||||
|
isSetNew = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSetNew {
|
||||||
|
// if we already have nftables rules with set for given direction
|
||||||
|
// just add new rule to the ruleset and return new fw.Rule object
|
||||||
|
|
||||||
|
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||||
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
|
}
|
||||||
|
// if ipset exists but it is not linked to rule for given direction
|
||||||
|
// create new rule for direction and bind ipset to it later
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
ifaceKey = expr.MetaKeyOIFNAME
|
||||||
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
// in that case not add IP match expression into the rule definition
|
||||||
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
var adrLen, adrOffset uint32
|
addrLen := uint32(len(rawIP))
|
||||||
if ip.To4() == nil {
|
addrOffset := uint32(12)
|
||||||
adrLen = 16
|
if addrLen == 16 {
|
||||||
adrOffset = 8
|
addrOffset = 8
|
||||||
} else {
|
|
||||||
adrLen = 4
|
|
||||||
adrOffset = 12
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// change to destination address position if need
|
// change to destination address position if need
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
adrOffset += adrLen
|
addrOffset += addrLen
|
||||||
}
|
}
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: adrOffset,
|
Offset: addrOffset,
|
||||||
Len: adrLen,
|
Len: addrLen,
|
||||||
},
|
},
|
||||||
|
)
|
||||||
|
// add individual IP for match if no ipset defined
|
||||||
|
if ipset == nil {
|
||||||
|
expressions = append(expressions,
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: add.AsSlice(),
|
Data: rawIP,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ipsetName,
|
||||||
|
SetID: ipset.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
if sPort != nil && len(sPort.Values) != 0 {
|
||||||
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
|
|||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
id := uuid.New().String()
|
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
|
||||||
|
|
||||||
_ = m.conn.InsertRule(&nftables.Rule{
|
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Position: 0,
|
Position: 0,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
if err := m.conn.Flush(); err != nil {
|
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
list, err := m.conn.GetRules(table, chain)
|
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||||
if err != nil {
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the rule to the chain
|
// getRulesetID returns ruleset ID based on given parameters
|
||||||
rule := &Rule{id: id}
|
func (m *Manager) getRulesetID(
|
||||||
for _, r := range list {
|
ip net.IP,
|
||||||
if bytes.Equal(r.UserData, userData) {
|
proto fw.Protocol,
|
||||||
rule.Rule = r
|
sPort *fw.Port,
|
||||||
break
|
dPort *fw.Port,
|
||||||
|
direction fw.RuleDirection,
|
||||||
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) string {
|
||||||
|
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||||
|
if sPort != nil {
|
||||||
|
rulesetID += sPort.String()
|
||||||
}
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
if dPort != nil {
|
||||||
|
rulesetID += dPort.String()
|
||||||
}
|
}
|
||||||
if rule.Rule == nil {
|
rulesetID += ":"
|
||||||
return nil, fmt.Errorf("rule not found")
|
rulesetID += strconv.Itoa(int(action))
|
||||||
|
if ipsetName == "" {
|
||||||
|
return "ip:" + ip.String() + rulesetID
|
||||||
|
}
|
||||||
|
return "set:" + ipsetName + rulesetID
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
// createSet in given table by name
|
||||||
|
func (m *Manager) createSet(
|
||||||
|
table *nftables.Table,
|
||||||
|
rawIP []byte,
|
||||||
|
name string,
|
||||||
|
) (*nftables.Set, error) {
|
||||||
|
keyType := nftables.TypeIPAddr
|
||||||
|
if len(rawIP) == 16 {
|
||||||
|
keyType = nftables.TypeIP6Addr
|
||||||
|
}
|
||||||
|
// else we create new ipset and continue creating rule
|
||||||
|
ipset := &nftables.Set{
|
||||||
|
Name: name,
|
||||||
|
Table: table,
|
||||||
|
Dynamic: true,
|
||||||
|
KeyType: keyType,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
|
return nil, fmt.Errorf("create set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush created set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
// chain returns the chain for the given IP address with specific settings
|
||||||
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||||
tables, err := m.conn.ListTablesOfFamily(family)
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
func (m *Manager) createChainIfNotExists(
|
||||||
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
return nil, fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
Policy: &polAccept,
|
Policy: &polAccept,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain = m.conn.AddChain(chain)
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
shiftDSTAddr := 0
|
shiftDSTAddr := 0
|
||||||
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
},
|
},
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
}
|
}
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
if err := m.conn.Flush(); err != nil {
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
nativeRule, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
if nativeRule.nftRule == nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
if nativeRule.nftSet != nil {
|
||||||
|
// call twice of delete set element raises error
|
||||||
|
// so we need to check if element is already removed
|
||||||
|
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||||
|
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||||
|
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.setRemovedIPs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.rulesetManager.deleteRule(nativeRule) {
|
||||||
|
// deleteRule indicates that we still have IP in the ruleset
|
||||||
|
// it means we should not remove the nftables rule but need to update set
|
||||||
|
// so we prepare IP to be removed from set on the next flush call
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||||
|
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nativeRule.nftRule = nil
|
||||||
|
|
||||||
|
if nativeRule.nftSet != nil {
|
||||||
|
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||||
|
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||||
|
}
|
||||||
|
nativeRule.nftSet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
chains, err := m.conn.ListChains()
|
chains, err := m.rConn.ListChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||||
m.conn.DelChain(c)
|
m.rConn.DelChain(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := m.conn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == FilterTableName {
|
||||||
m.conn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (m *Manager) Flush() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set must be removed after flush rule changes
|
||||||
|
// otherwise we will get error
|
||||||
|
for _, s := range m.setRemoved {
|
||||||
|
m.rConn.FlushSet(s)
|
||||||
|
m.rConn.DelSet(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.setRemoved) > 0 {
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.setRemovedIPs = map[string]struct{}{}
|
||||||
|
m.setRemoved = map[string]*nftables.Set{}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime = backoffTime * 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
if table == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := m.rConn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) != 0 {
|
||||||
|
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||||
|
log.Errorf("failed to set rule handle: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
func encodePort(port fw.Port) []byte {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
fw.RuleDirectionIN,
|
fw.RuleDirectionIN,
|
||||||
fw.ActionDrop,
|
fw.ActionDrop,
|
||||||
"",
|
"",
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
// test expectations:
|
// test expectations:
|
||||||
// 1) regular rule
|
// 1) regular rule
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
// 2) "accept extra routed traffic rule" for the interface
|
||||||
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
err = manager.DeleteRule(rule)
|
err = manager.DeleteRule(rule)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// test expectations:
|
// test expectations:
|
||||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(); err != nil {
|
||||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
if i%100 == 0 {
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
}
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
*nftables.Rule
|
nftRule *nftables.Rule
|
||||||
id string
|
nftSet *nftables.Set
|
||||||
|
|
||||||
|
ruleID string
|
||||||
|
ip []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||||
|
type nftRuleset struct {
|
||||||
|
nftRule *nftables.Rule
|
||||||
|
nftSet *nftables.Set
|
||||||
|
issuedRules map[string]*Rule
|
||||||
|
rulesetID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type rulesetManager struct {
|
||||||
|
rulesets map[string]*nftRuleset
|
||||||
|
|
||||||
|
nftSetName2rulesetID map[string]string
|
||||||
|
issuedRuleID2rulesetID map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRuleManager() *rulesetManager {
|
||||||
|
return &rulesetManager{
|
||||||
|
rulesets: map[string]*nftRuleset{},
|
||||||
|
|
||||||
|
nftSetName2rulesetID: map[string]string{},
|
||||||
|
issuedRuleID2rulesetID: map[string]string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||||
|
ruleset, ok := r.rulesets[rulesetID]
|
||||||
|
return ruleset, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) createRuleset(
|
||||||
|
rulesetID string,
|
||||||
|
nftRule *nftables.Rule,
|
||||||
|
nftSet *nftables.Set,
|
||||||
|
) *nftRuleset {
|
||||||
|
ruleset := nftRuleset{
|
||||||
|
rulesetID: rulesetID,
|
||||||
|
nftRule: nftRule,
|
||||||
|
nftSet: nftSet,
|
||||||
|
issuedRules: map[string]*Rule{},
|
||||||
|
}
|
||||||
|
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||||
|
if nftSet != nil {
|
||||||
|
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||||
|
}
|
||||||
|
return &ruleset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) addRule(
|
||||||
|
ruleset *nftRuleset,
|
||||||
|
ip []byte,
|
||||||
|
) (*Rule, error) {
|
||||||
|
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||||
|
return nil, fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
nftRule: ruleset.nftRule,
|
||||||
|
nftSet: ruleset.nftSet,
|
||||||
|
ruleID: xid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset.issuedRules[rule.ruleID] = &rule
|
||||||
|
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteRule from ruleset and returns true if contains other rules
|
||||||
|
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||||
|
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset := r.rulesets[rulesetID]
|
||||||
|
if ruleset.nftRule == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||||
|
delete(ruleset.issuedRules, rule.ruleID)
|
||||||
|
|
||||||
|
if len(ruleset.issuedRules) == 0 {
|
||||||
|
delete(r.rulesets, ruleset.rulesetID)
|
||||||
|
if rule.nftSet != nil {
|
||||||
|
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||||
|
//
|
||||||
|
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||||
|
// we set correct handle value to it.
|
||||||
|
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||||
|
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||||
|
ruleset, ok := r.rulesets[string(split[0])]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
*ruleset.nftRule = *nftRule
|
||||||
|
return nil
|
||||||
|
}
|
||||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{
|
||||||
|
UserData: []byte(rulesetID),
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||||
|
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_addRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||||
|
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||||
|
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||||
|
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||||
|
|
||||||
|
ruleset2 := &nftRuleset{
|
||||||
|
rulesetID: "ruleset-2",
|
||||||
|
}
|
||||||
|
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||||
|
require.Error(t, err, "addRule() should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
ip2 := []byte("192.168.1.1")
|
||||||
|
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule2, "rule should not be nil")
|
||||||
|
|
||||||
|
hasNext := rulesetManager.deleteRule(rule)
|
||||||
|
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||||
|
|
||||||
|
// Check that the rule is no longer in the manager.
|
||||||
|
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||||
|
|
||||||
|
hasNext = rulesetManager.deleteRule(rule2)
|
||||||
|
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.0.1")
|
||||||
|
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
nftRuleCopy := nftRule
|
||||||
|
nftRuleCopy.Handle = 2
|
||||||
|
nftRuleCopy.UserData = []byte(rulesetID)
|
||||||
|
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||||
|
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||||
|
// check correct work with references
|
||||||
|
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
nftSet := nftables.Set{
|
||||||
|
ID: 2,
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
|
||||||
|
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||||
|
require.True(t, ok, "getRuleset() failed")
|
||||||
|
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||||
|
|
||||||
|
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||||
|
require.False(t, ok, "getRuleset() failed")
|
||||||
|
}
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
// Protocol is the protocol of the port
|
||||||
type Protocol string
|
type Protocol string
|
||||||
|
|
||||||
@@ -28,3 +32,15 @@ type Port struct {
|
|||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
Values []int
|
Values []int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String interface implementation
|
||||||
|
func (p *Port) String() string {
|
||||||
|
var ports string
|
||||||
|
for _, port := range p.Values {
|
||||||
|
if ports != "" {
|
||||||
|
ports += ","
|
||||||
|
}
|
||||||
|
ports += strconv.Itoa(port)
|
||||||
|
}
|
||||||
|
return ports
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,11 +21,13 @@ type IFaceMapper interface {
|
|||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RuleSet is a set of rules grouped by a string key
|
||||||
|
type RuleSet map[string]Rule
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules []Rule
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules []Rule
|
incomingRules map[string]RuleSet
|
||||||
rulesIndex map[string]int
|
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
|
|
||||||
@@ -48,7 +50,6 @@ type decoder struct {
|
|||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface IFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
outgoingRules: make(map[string]RuleSet),
|
||||||
|
incomingRules: make(map[string]RuleSet),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var p int
|
|
||||||
if direction == fw.RuleDirectionIN {
|
if direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules, r)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.incomingRules) - 1
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
} else {
|
}
|
||||||
m.outgoingRules = append(m.outgoingRules, r)
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
p = len(m.outgoingRules) - 1
|
} else {
|
||||||
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
|
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
m.rulesIndex[r.id] = p
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &r, nil
|
return &r, nil
|
||||||
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := m.rulesIndex[r.id]
|
if r.direction == fw.RuleDirectionIN {
|
||||||
|
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.rulesIndex, r.id)
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
|
|
||||||
var toUpdate []Rule
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
|
||||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
|
||||||
toUpdate = m.incomingRules
|
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.outgoingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(toUpdate); i++ {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = m.outgoingRules[:0]
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = m.incomingRules[:0]
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
m.rulesIndex = make(map[string]int)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||||
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter imlements same logic for booth direction of the traffic
|
// dropFilter imlements same logic for booth direction of the traffic
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
// check if IP address match by IP
|
var ip net.IP
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.matchByIP {
|
|
||||||
switch ipLayer {
|
switch ipLayer {
|
||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
if isIncomingPacket {
|
if isIncomingPacket {
|
||||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
ip = d.ip4.SrcIP
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
ip = d.ip4.DstIP
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
if isIncomingPacket {
|
if isIncomingPacket {
|
||||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
ip = d.ip6.SrcIP
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
ip = d.ip6.DstIP
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// default policy is DROP ALL
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var toUpdate []Rule
|
|
||||||
if in {
|
if in {
|
||||||
r.direction = fw.RuleDirectionIN
|
r.direction = fw.RuleDirectionIN
|
||||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.incomingRules
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.outgoingRules
|
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range toUpdate {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -345,15 +359,19 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
for _, r := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
return m.DeleteRule(&r)
|
return m.DeleteRule(&r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, r := range m.outgoingRules {
|
}
|
||||||
|
for _, arr := range m.outgoingRules {
|
||||||
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
return m.DeleteRule(&r)
|
return m.DeleteRule(&r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action = fw.ActionDrop
|
action = fw.ActionDrop
|
||||||
comment = "Test rule 2"
|
comment = "Test rule 2"
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
err = m.DeleteRule(rule2)
|
||||||
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||||
t.Errorf("rule1 still in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
incomingRules: []Rule{},
|
incomingRules: map[string]RuleSet{},
|
||||||
outgoingRules: []Rule{},
|
outgoingRules: map[string]RuleSet{},
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule Rule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.incomingRules[0]
|
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(manager.outgoingRules) != 1 {
|
if len(manager.outgoingRules) != 1 {
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.outgoingRules[0]
|
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if !tt.ip.Equal(addedRule.ip) {
|
||||||
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure rulesIndex is correctly updated
|
|
||||||
index, ok := manager.rulesIndex[addedRule.id]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("expected rule to be in rulesIndex")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if index != 0 {
|
|
||||||
t.Errorf("expected rule index to be 0, got %d", index)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||||
t.Errorf("rules is not empty")
|
t.Errorf("rules is not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -346,12 +338,14 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
|
for _, rule := range arr {
|
||||||
if rule.id == hookID {
|
if rule.id == hookID {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
t.Fatalf("The hook was not added properly.")
|
t.Fatalf("The hook was not added properly.")
|
||||||
@@ -364,12 +358,14 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
|
for _, rule := range arr {
|
||||||
if rule.id == hookID {
|
if rule.id == hookID {
|
||||||
t.Fatalf("The hook was not removed properly.")
|
t.Fatalf("The hook was not removed properly.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -31,10 +34,23 @@ type Manager interface {
|
|||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
manager firewall.Manager
|
manager firewall.Manager
|
||||||
|
ipsetCounter int
|
||||||
rulesPairs map[string][]firewall.Rule
|
rulesPairs map[string][]firewall.Rule
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ipsetInfo struct {
|
||||||
|
name string
|
||||||
|
ipCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
|
return &DefaultManager{
|
||||||
|
manager: fm,
|
||||||
|
rulesPairs: make(map[string][]firewall.Rule),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
//
|
//
|
||||||
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
||||||
@@ -42,11 +58,28 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
total := 0
|
||||||
|
for _, pairs := range d.rulesPairs {
|
||||||
|
total += len(pairs)
|
||||||
|
}
|
||||||
|
log.Infof(
|
||||||
|
"ACL rules processed in: %v, total rules count: %d",
|
||||||
|
time.Since(start), total)
|
||||||
|
}()
|
||||||
|
|
||||||
if d.manager == nil {
|
if d.manager == nil {
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := d.manager.Flush(); err != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
enableSSH := (networkMap.PeerConfig != nil &&
|
||||||
@@ -94,14 +127,38 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
|
|
||||||
applyFailed := false
|
applyFailed := false
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[string][]firewall.Rule)
|
||||||
|
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||||
|
|
||||||
|
// calculate which IP's can be grouped in by which ipset
|
||||||
|
// to do that we use rule selector (which is just rule properties without IP's)
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
rulePair, err := d.protoRuleToFirewallRule(r)
|
selector := d.getRuleGroupingSelector(r)
|
||||||
|
ipset, ok := ipsetByRuleSelectors[selector]
|
||||||
|
if !ok {
|
||||||
|
ipset = &ipsetInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipset.ipCount++
|
||||||
|
ipsetByRuleSelectors[selector] = ipset
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
|
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||||
|
ipsetName := ""
|
||||||
|
if ipset.name == "" {
|
||||||
|
d.ipsetCounter++
|
||||||
|
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
|
}
|
||||||
|
ipsetName = ipset.name
|
||||||
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
applyFailed = true
|
applyFailed = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
newRulePairs[rulePair[0].GetRuleID()] = rulePair
|
newRulePairs[pairID] = rulePair
|
||||||
}
|
}
|
||||||
if applyFailed {
|
if applyFailed {
|
||||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
||||||
@@ -140,55 +197,71 @@ func (d *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) {
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
|
r *mgmProto.FirewallRule,
|
||||||
|
ipsetName string,
|
||||||
|
) (string, []firewall.Rule, error) {
|
||||||
ip := net.ParseIP(r.PeerIP)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := convertToFirewallProtocol(r.Protocol)
|
protocol := convertToFirewallProtocol(r.Protocol)
|
||||||
if protocol == firewall.ProtocolUnknown {
|
if protocol == firewall.ProtocolUnknown {
|
||||||
return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
action := convertFirewallAction(r.Action)
|
action := convertFirewallAction(r.Action)
|
||||||
if action == firewall.ActionUnknown {
|
if action == firewall.ActionUnknown {
|
||||||
return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||||
}
|
}
|
||||||
|
|
||||||
var port *firewall.Port
|
var port *firewall.Port
|
||||||
if r.Port != "" {
|
if r.Port != "" {
|
||||||
value, err := strconv.Atoi(r.Port)
|
value, err := strconv.Atoi(r.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid port, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||||
}
|
}
|
||||||
port = &firewall.Port{
|
port = &firewall.Port{
|
||||||
Values: []int{value},
|
Values: []int{value},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||||
|
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||||
|
return ruleID, rulesPair, nil
|
||||||
|
}
|
||||||
|
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
var err error
|
var err error
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.FirewallRule_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.FirewallRule_OUT:
|
case mgmProto.FirewallRule_OUT:
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.rulesPairs[rules[0].GetRuleID()] = rules
|
d.rulesPairs[ruleID] = rules
|
||||||
return rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addInRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -198,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -206,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return append(rules, rule), nil
|
return append(rules, rule), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addOutRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -218,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -226,6 +309,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return append(rules, rule), nil
|
return append(rules, rule), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||||
|
func (d *DefaultManager) getRuleID(
|
||||||
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
direction int,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
comment string,
|
||||||
|
) string {
|
||||||
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||||
|
if port != nil {
|
||||||
|
idStr += port.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
||||||
|
}
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||||
//
|
//
|
||||||
@@ -235,7 +335,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
networkMap *mgmProto.NetworkMap,
|
networkMap *mgmProto.NetworkMap,
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
||||||
totalIPs := 0
|
totalIPs := 0
|
||||||
for _, p := range networkMap.RemotePeers {
|
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||||
for range p.AllowedIps {
|
for range p.AllowedIps {
|
||||||
totalIPs++
|
totalIPs++
|
||||||
}
|
}
|
||||||
@@ -246,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
in := protoMatch{}
|
in := protoMatch{}
|
||||||
out := protoMatch{}
|
out := protoMatch{}
|
||||||
|
|
||||||
|
// trace which type of protocols was squashed
|
||||||
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
|
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||||
// But we zeroed the IP's for protocol if:
|
// But we zeroed the IP's for protocol if:
|
||||||
@@ -262,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = map[string]int{}
|
||||||
}
|
}
|
||||||
match := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := match[r.PeerIP]; ok {
|
// special case, when we recieve this all network IP address
|
||||||
|
// it means that rules for that protocol was already optimized on the
|
||||||
|
// management side
|
||||||
|
if r.PeerIP == "0.0.0.0" {
|
||||||
|
squashedRules = append(squashedRules, r)
|
||||||
|
squashedProtocols[r.Protocol] = struct{}{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
match[r.PeerIP] = i
|
|
||||||
|
ipset := protocols[r.Protocol]
|
||||||
|
|
||||||
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipset[r.PeerIP] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
@@ -288,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.FirewallRule_UDP,
|
mgmProto.FirewallRule_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||||
@@ -346,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return append(rules, squashedRules...), squashedProtocols
|
return append(rules, squashedRules...), squashedProtocols
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.FirewallRule_TCP:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,6 +55,11 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
|
existedPairs := map[string]struct{}{}
|
||||||
|
for id := range acl.rulesPairs {
|
||||||
|
existedPairs[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
||||||
networkMap.FirewallRules = append(
|
networkMap.FirewallRules = append(
|
||||||
@@ -67,11 +72,6 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
existedRulesID := map[string]struct{}{}
|
|
||||||
for id := range acl.rulesPairs {
|
|
||||||
existedRulesID[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
// we should have one old and one new rule in the existed rules
|
||||||
@@ -80,13 +80,16 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that old rules was removed
|
// check that old rule was removed
|
||||||
for id := range existedRulesID {
|
previousCount := 0
|
||||||
if _, ok := acl.rulesPairs[id]; ok {
|
for id := range acl.rulesPairs {
|
||||||
t.Errorf("old rule was not removed")
|
if _, ok := existedPairs[id]; ok {
|
||||||
return
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if previousCount != 1 {
|
||||||
|
t.Errorf("old rule was not removed")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
t.Run("handle default rules", func(t *testing.T) {
|
||||||
|
|||||||
@@ -215,11 +215,13 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
if *input.PreSharedKey != "" {
|
||||||
|
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||||
*input.PreSharedKey, config.PreSharedKey)
|
*input.PreSharedKey, config.PreSharedKey)
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.SSHKey == "" {
|
if config.SSHKey == "" {
|
||||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
|
|||||||
@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
// case 4: existing config, but new managementURL has been provided -> update config
|
// case 4: new empty pre-shared key config -> fetch it
|
||||||
|
newPreSharedKey := ""
|
||||||
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
AdminURL: adminURL,
|
||||||
|
ConfigPath: path,
|
||||||
|
PreSharedKey: &newPreSharedKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
|
// case 5: existing config, but new managementURL has been provided -> update config
|
||||||
newManagementURL := "https://test.newManagement.url:33071"
|
newManagementURL := "https://test.newManagement.url:33071"
|
||||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: newManagementURL,
|
ManagementURL: newManagementURL,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -24,7 +25,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||||
|
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunClientMobile with main logic on mobile system
|
||||||
|
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||||
|
// in case of non Android os these variables will be nil
|
||||||
|
mobileDependency := MobileDependency{
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
|
RouteListener: routeListener,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
}
|
||||||
|
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in case of non Android os these variables will be nil
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||||
md := MobileDependency{
|
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
RouteListener: routeListener,
|
|
||||||
}
|
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
func newHostManager(_ WGIface) (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,7 +25,7 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -33,7 +31,7 @@ type registryConfigurator struct {
|
|||||||
existingSearchDomains []string
|
existingSearchDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import (
|
|||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@@ -18,7 +16,7 @@ type resolvconf struct {
|
|||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface.Name(),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -3,29 +3,20 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
defaultPort = 53
|
type ReadyListener interface {
|
||||||
customPort = 5053
|
OnReady()
|
||||||
defaultIP = "127.0.0.1"
|
}
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
@@ -33,6 +24,7 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
@@ -42,21 +34,19 @@ type DefaultServer struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
fakeResolverWG sync.WaitGroup
|
service service
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface *iface.WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig hostDNSConfig
|
currentConfig hostDNSConfig
|
||||||
customAddress *netip.AddrPort
|
|
||||||
enabled bool
|
// permanent related properties
|
||||||
|
permanent bool
|
||||||
|
hostsDnsList []string
|
||||||
|
hostsDnsListLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
@@ -70,9 +60,7 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -82,34 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
addrPort = &parsedAddrPort
|
addrPort = &parsedAddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
var dnsService service
|
||||||
|
if wgInterface.IsUserspaceBind() {
|
||||||
|
dnsService = newServiceViaMemory(wgInterface)
|
||||||
|
} else {
|
||||||
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
|
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||||
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||||
|
ds.permanent = true
|
||||||
|
ds.hostsDnsList = hostsDnsList
|
||||||
|
ds.addHostRootZone()
|
||||||
|
setServerDns(ds)
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
server: &dns.Server{
|
service: dnsService,
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
},
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if initialDnsCfg != nil {
|
return defaultServer
|
||||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultServer.evalRuntimeAddress()
|
// Initialize instantiate host manager and the dns service
|
||||||
return defaultServer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -118,74 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.permanent {
|
||||||
|
err = s.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.hostManager, err = newHostManager(s.wgInterface)
|
s.hostManager, err = newHostManager(s.wgInterface)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen runs the listener in a go routine
|
// DnsIP returns the DNS resolver server IP address
|
||||||
func (s *DefaultServer) listen() {
|
//
|
||||||
// nil check required in unit tests
|
// When kernel space interface used it return real DNS server listener IP address
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||||
s.fakeResolverWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
hookID := s.filterDNSTraffic()
|
|
||||||
s.fakeResolverWG.Wait()
|
|
||||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) DnsIP() string {
|
func (s *DefaultServer) DnsIP() string {
|
||||||
if !s.enabled {
|
return s.service.RuntimeIP()
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s.runtimeIP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
@@ -194,34 +141,30 @@ func (s *DefaultServer) Stop() {
|
|||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
|
||||||
|
if s.hostManager != nil {
|
||||||
err := s.hostManager.restoreHostDNS()
|
err := s.hostManager.restoreHostDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
|
||||||
s.fakeResolverWG.Done()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.stopListener()
|
s.service.Stop()
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
if !s.listenerIsRunning {
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
return nil
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
}
|
s.hostsDnsListLock.Lock()
|
||||||
|
defer s.hostsDnsListLock.Unlock()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
s.hostsDnsList = hostsDnsList
|
||||||
defer cancel()
|
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||||
|
if ok {
|
||||||
err := s.server.ShutdownContext(ctx)
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
if err != nil {
|
return
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
}
|
||||||
return nil
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
@@ -272,16 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be disabled, we stop the listener or fake resolver
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if !update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
_ = s.service.Listen()
|
||||||
s.fakeResolverWG.Done()
|
} else if !s.permanent {
|
||||||
} else {
|
s.service.Stop()
|
||||||
if err := s.stopListener(); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.listen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
@@ -292,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
s.updateLocalResolver(localRecords)
|
s.updateLocalResolver(localRecords)
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
hostUpdate := s.currentConfig
|
||||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||||
hostUpdate.routeAll = false
|
hostUpdate.routeAll = false
|
||||||
@@ -405,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
|
||||||
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerMux(update.domain, update.handler)
|
s.service.RegisterMux(update.domain, update.handler)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if update.domain == nbdns.RootZone {
|
||||||
|
isContainRootUpdate = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
|
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||||
|
s.hostsDnsListLock.Lock()
|
||||||
|
s.addHostRootZone()
|
||||||
|
s.hostsDnsListLock.Unlock()
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
s.deregisterMux(key)
|
} else {
|
||||||
|
existingHandler.stop()
|
||||||
|
s.service.DeregisterMux(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -448,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
@@ -483,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
for i, item := range s.currentConfig.domains {
|
for i, item := range s.currentConfig.domains {
|
||||||
if _, found := removeIndex[item.domain]; found {
|
if _, found := removeIndex[item.domain]; found {
|
||||||
s.currentConfig.domains[i].disabled = true
|
s.currentConfig.domains[i].disabled = true
|
||||||
s.deregisterMux(item.domain)
|
s.service.DeregisterMux(item.domain)
|
||||||
removeIndex[item.domain] = i
|
removeIndex[item.domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -500,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.domains[i].disabled = false
|
s.currentConfig.domains[i].disabled = false
|
||||||
s.registerMux(domain, handler)
|
s.service.RegisterMux(domain, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -516,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) filterDNSTraffic() string {
|
func (s *DefaultServer) addHostRootZone() {
|
||||||
filter := s.wgInterface.GetFilter()
|
handler := newUpstreamResolver(s.ctx)
|
||||||
if filter == nil {
|
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||||
log.Error("can't set DNS filter, filter not initialized")
|
for n, ua := range s.hostsDnsList {
|
||||||
return ""
|
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||||
}
|
}
|
||||||
|
handler.deactivate = func() {}
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
handler.reactivate = func() {}
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
|
||||||
// Decode the packet
|
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
|
||||||
log.Tracef("parse DNS request: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := responseWriter{
|
|
||||||
packet: packet,
|
|
||||||
device: s.wgInterface.GetDevice().Device,
|
|
||||||
}
|
|
||||||
go s.dnsMux.ServeDNS(&writer, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) evalRuntimeAddress() {
|
|
||||||
defer func() {
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
|
||||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
|
||||||
s.runtimePort = defaultPort
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
|
||||||
for i := 0; i < len(network.IP); i++ {
|
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to big.Int
|
|
||||||
endInt := big.NewInt(0)
|
|
||||||
endInt.SetBytes(endIP)
|
|
||||||
|
|
||||||
// subtract fromEnd from the last ip
|
|
||||||
fromEndBig := big.NewInt(int64(fromEnd))
|
|
||||||
resultInt := big.NewInt(0)
|
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
|
||||||
|
|
||||||
return net.IP(resultInt.Bytes()).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
|
||||||
for _, c := range cfg.NameServerGroups {
|
|
||||||
if c.Primary {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mutex sync.Mutex
|
||||||
|
server Server
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||||
|
func GetServerDns() (Server, error) {
|
||||||
|
mutex.Lock()
|
||||||
|
if server == nil {
|
||||||
|
mutex.Unlock()
|
||||||
|
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||||
|
}
|
||||||
|
s := server
|
||||||
|
mutex.Unlock()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setServerDns(newServerServer Server) {
|
||||||
|
mutex.Lock()
|
||||||
|
server = newServerServer
|
||||||
|
defer mutex.Unlock()
|
||||||
|
}
|
||||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetServerDns(t *testing.T) {
|
||||||
|
_, err := GetServerDns()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("invalid dns server instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &MockServer{}
|
||||||
|
setServerDns(srv)
|
||||||
|
|
||||||
|
srvB, err := GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srvB != srv {
|
||||||
|
t.Errorf("missmatch dns instances")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,17 +5,59 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/golang/mock/gomock"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mocWGIface struct {
|
||||||
|
filter iface.PacketFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Name() string {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Address() iface.WGAddress {
|
||||||
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||||
|
return w.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||||
|
w.filter = filter
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -26,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.SetLevel(log.TraceLevel)
|
||||||
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -221,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -239,9 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
// pretend we are running
|
|
||||||
dnsServer.listenerIsRunning = true
|
|
||||||
dnsServer.fakeResolverWG.Add(1)
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -276,6 +320,133 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create stdnet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("crate and init wireguard interface: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = wgIface.Close(); err != nil {
|
||||||
|
t.Logf("close wireguard interface: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("parse CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
|
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
|
|
||||||
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("run DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||||
|
t.Logf("restore DNS settings on the host: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||||
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||||
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server with regular configuration
|
||||||
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := update
|
||||||
|
update2.ServiceEnable = false
|
||||||
|
// Disable the server, stop the listener
|
||||||
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update3 := update2
|
||||||
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||||
|
// But service still get updates and we checking that we handle
|
||||||
|
// internal state in the right way
|
||||||
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -292,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||||
|
if err != nil {
|
||||||
dnsServer.hostManager = newNoopHostMocker()
|
t.Fatalf("%v", err)
|
||||||
dnsServer.listen()
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
if !dnsServer.listenerIsRunning {
|
|
||||||
t.Fatal("dns server listener is not running")
|
|
||||||
}
|
}
|
||||||
|
dnsServer.hostManager = newNoopHostMocker()
|
||||||
|
err = dnsServer.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns server is not running: %s", err)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -314,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
d := net.Dialer{
|
d := net.Dialer{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
conn, err := d.DialContext(ctx, network, addr)
|
conn, err := d.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
@@ -349,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
dnsMux: dns.DefaultServeMux,
|
service: newServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -412,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||||
mux := dns.NewServeMux()
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
|
||||||
var parsedAddrPort *netip.AddrPort
|
|
||||||
if addrPort != "" {
|
|
||||||
parsed, err := netip.ParseAddrPort(addrPort)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal("failed to initialize wg interface")
|
||||||
}
|
|
||||||
parsedAddrPort = &parsed
|
|
||||||
}
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
var dnsList []string
|
||||||
Net: "udp",
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
||||||
Handler: mux,
|
err = dnsServer.Initialize()
|
||||||
UDPSize: 65535,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
|
||||||
|
|
||||||
ds := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: cancel,
|
|
||||||
server: dnsServer,
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
customAddress: parsedAddrPort,
|
|
||||||
}
|
|
||||||
ds.evalRuntimeAddress()
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||||
if lastIP != tt.ip {
|
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Enabled: true,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(2, update2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Domains: []string{"customdomain.com"},
|
||||||
|
Primary: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create stdnet: %v", err)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build interface wireguard: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("crate and init wireguard interface: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pf, err := uspfilter.Create(wgIface)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.SetFilter(pf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("set packet filter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||||
|
return &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{
|
||||||
|
Timeout: time.Second * 3,
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
return d.DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
)
|
||||||
|
|
||||||
|
type service interface {
|
||||||
|
Listen() error
|
||||||
|
Stop()
|
||||||
|
RegisterMux(domain string, handler dns.Handler)
|
||||||
|
DeregisterMux(key string)
|
||||||
|
RuntimePort() int
|
||||||
|
RuntimeIP() string
|
||||||
|
}
|
||||||
145
client/internal/dns/service_listener.go
Normal file
145
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaListener struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
customAddr *netip.AddrPort
|
||||||
|
server *dns.Server
|
||||||
|
runtimeIP string
|
||||||
|
runtimePort int
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
s := &serviceViaListener{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: mux,
|
||||||
|
customAddr: customAddr,
|
||||||
|
server: &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to eval runtime address: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimePort() int {
|
||||||
|
return s.runtimePort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimeIP() string {
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
||||||
|
if s.customAddr != nil {
|
||||||
|
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getFirstListenerAvailable()
|
||||||
|
}
|
||||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaMemory struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
runtimeIP string
|
||||||
|
runtimePort int
|
||||||
|
udpFilterHookID string
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||||
|
s := &serviceViaMemory{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
|
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listenerIsRunning = true
|
||||||
|
|
||||||
|
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.listenerIsRunning = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimePort() int {
|
||||||
|
return s.runtimePort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimeIP() string {
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||||
|
filter := s.wgInterface.GetFilter()
|
||||||
|
if filter == nil {
|
||||||
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||||
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
|
log.Tracef("parse DNS request: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := responseWriter{
|
||||||
|
packet: packet,
|
||||||
|
device: s.wgInterface.GetDevice().Device,
|
||||||
|
}
|
||||||
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||||
|
// Calculate the last IP in the CIDR range
|
||||||
|
var endIP net.IP
|
||||||
|
for i := 0; i < len(network.IP); i++ {
|
||||||
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to big.Int
|
||||||
|
endInt := big.NewInt(0)
|
||||||
|
endInt.SetBytes(endIP)
|
||||||
|
|
||||||
|
// subtract fromEnd from the last ip
|
||||||
|
fromEndBig := big.NewInt(int64(fromEnd))
|
||||||
|
resultInt := big.NewInt(0)
|
||||||
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
|
return net.IP(resultInt.Bytes()).String()
|
||||||
|
}
|
||||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
ip string
|
||||||
|
}{
|
||||||
|
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||||
|
{"192.168.0.0/30", "192.168.0.2"},
|
||||||
|
{"192.168.0.0/16", "192.168.255.254"},
|
||||||
|
{"192.168.0.0/24", "192.168.0.254"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error parsing CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||||
|
if lastIP != tt.ip {
|
||||||
|
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
}
|
||||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
GetInterfaceGUIDString() (string, error)
|
||||||
|
}
|
||||||
@@ -190,23 +190,25 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var routes []*route.Route
|
var routes []*route.Route
|
||||||
var dnsCfg *nbdns.Config
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
routes, dnsCfg, err = e.readInitialSettings()
|
routes, err = e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if e.dnsServer == nil {
|
if e.dnsServer == nil {
|
||||||
|
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
||||||
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// todo fix custom address
|
// todo fix custom address
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
if e.dnsServer == nil {
|
||||||
|
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||||
@@ -605,6 +607,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
// cleanup request, most likely our peer has been deleted
|
// cleanup request, most likely our peer has been deleted
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -624,6 +627,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
// update SSHServer by adding remote peer SSH keys
|
||||||
if !isNil(e.sshServer) {
|
if !isNil(e.sshServer) {
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
for _, config := range networkMap.GetRemotePeers() {
|
||||||
@@ -759,17 +764,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
e.peerConns[peerKey] = conn
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey)
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go e.connWorker(conn, peerKey)
|
go e.connWorker(conn, peerKey)
|
||||||
}
|
}
|
||||||
err := e.statusRecorder.UpdatePeerFQDN(peerKey, peerConfig.Fqdn)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerKey, err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1046,14 +1047,13 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
||||||
netMap, err := e.mgmClient.GetNetworkMap()
|
netMap, err := e.mgmClient.GetNetworkMap()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
return routes, nil
|
||||||
return routes, &dnsCfg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -11,4 +12,6 @@ type MobileDependency struct {
|
|||||||
TunAdapter iface.TunAdapter
|
TunAdapter iface.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
RouteListener routemanager.RouteListener
|
RouteListener routemanager.RouteListener
|
||||||
|
HostDNSAddresses []string
|
||||||
|
DnsReadyListener dns.ReadyListener
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,11 @@ type Status struct {
|
|||||||
mgmAddress string
|
mgmAddress string
|
||||||
signalAddress string
|
signalAddress string
|
||||||
notifier *notifier
|
notifier *notifier
|
||||||
|
|
||||||
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
|
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
|
||||||
|
peerListChangedForNotification bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
// NewRecorder returns a new Status instance
|
||||||
@@ -78,11 +83,13 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.offlinePeers = make([]State, len(replacement))
|
d.offlinePeers = make([]State, len(replacement))
|
||||||
copy(d.offlinePeers, replacement)
|
copy(d.offlinePeers, replacement)
|
||||||
d.notifyPeerListChanged()
|
|
||||||
|
// todo we should set to true in case if the list changed only
|
||||||
|
d.peerListChangedForNotification = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds peer to Daemon status map
|
// AddPeer adds peer to Daemon status map
|
||||||
func (d *Status) AddPeer(peerPubKey string) error {
|
func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
@@ -90,7 +97,12 @@ func (d *Status) AddPeer(peerPubKey string) error {
|
|||||||
if ok {
|
if ok {
|
||||||
return errors.New("peer already exist")
|
return errors.New("peer already exist")
|
||||||
}
|
}
|
||||||
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected}
|
d.peers[peerPubKey] = State{
|
||||||
|
PubKey: peerPubKey,
|
||||||
|
ConnStatus: StatusDisconnected,
|
||||||
|
FQDN: fqdn,
|
||||||
|
}
|
||||||
|
d.peerListChangedForNotification = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,13 +124,13 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
_, ok := d.peers[peerPubKey]
|
||||||
if ok {
|
if !ok {
|
||||||
delete(d.peers, peerPubKey)
|
return errors.New("no peer with to remove")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
delete(d.peers, peerPubKey)
|
||||||
return errors.New("no peer with to remove")
|
d.peerListChangedForNotification = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerState updates peer status
|
// UpdatePeerState updates peer status
|
||||||
@@ -188,10 +200,23 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
peerState.FQDN = fqdn
|
peerState.FQDN = fqdn
|
||||||
d.peers[peerPubKey] = peerState
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FinishPeerListModifications this event invoke the notification
|
||||||
|
func (d *Status) FinishPeerListModifications() {
|
||||||
|
d.mux.Lock()
|
||||||
|
|
||||||
|
if !d.peerListChangedForNotification {
|
||||||
|
d.mux.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.peerListChangedForNotification = false
|
||||||
|
d.mux.Unlock()
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import (
|
|||||||
func TestAddPeer(t *testing.T) {
|
func TestAddPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key)
|
err := status.AddPeer(key, "abc.netbird")
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
_, exists := status.peers[key]
|
||||||
assert.True(t, exists, "value was found")
|
assert.True(t, exists, "value was found")
|
||||||
|
|
||||||
err = status.AddPeer(key)
|
err = status.AddPeer(key, "abc.netbird")
|
||||||
|
|
||||||
assert.Error(t, err, "should return error on duplicate")
|
assert.Error(t, err, "should return error on duplicate")
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ func TestAddPeer(t *testing.T) {
|
|||||||
func TestGetPeer(t *testing.T) {
|
func TestGetPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key)
|
err := status.AddPeer(key, "abc.netbird")
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
peerStatus, err := status.GetPeer(key)
|
peerStatus, err := status.GetPeer(key)
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,34 +28,14 @@ func genKey(format string, input string) string {
|
|||||||
|
|
||||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||||
ctx, cancel := context.WithCancel(parentCTX)
|
manager, err := newNFTablesManager(parentCTX)
|
||||||
|
if err == nil {
|
||||||
if isIptablesSupported() {
|
log.Debugf("nftables firewall manager will be used")
|
||||||
log.Debugf("iptables is supported")
|
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
|
|
||||||
return &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("iptables is not supported, using nftables")
|
|
||||||
|
|
||||||
manager := &nftablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
|
||||||
|
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
log.Debugf("fallback to iptables firewall manager: %s", err)
|
||||||
|
return newIptablesManager(parentCTX)
|
||||||
|
}
|
||||||
|
|
||||||
func getInPair(pair routerPair) routerPair {
|
func getInPair(pair routerPair) routerPair {
|
||||||
return routerPair{
|
return routerPair{
|
||||||
|
|||||||
@@ -49,6 +49,28 @@ type iptablesManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if !isIptablesClientAvailable(ipv4Client) {
|
||||||
|
log.Infof("iptables is missing for ipv4")
|
||||||
|
ipv4Client = nil
|
||||||
|
}
|
||||||
|
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if !isIptablesClientAvailable(ipv6Client) {
|
||||||
|
log.Infof("iptables is missing for ipv6")
|
||||||
|
ipv6Client = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &iptablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
ipv4Client: ipv4Client,
|
||||||
|
ipv6Client: ipv6Client,
|
||||||
|
rules: make(map[string]map[string][]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||||
func (i *iptablesManager) CleanRoutingRules() {
|
func (i *iptablesManager) CleanRoutingRules() {
|
||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
@@ -61,6 +83,7 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
|
|
||||||
log.Debug("flushing tables")
|
log.Debug("flushing tables")
|
||||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||||
|
if i.ipv4Client != nil {
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
@@ -70,7 +93,9 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.ipv6Client != nil {
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
@@ -80,6 +105,7 @@ func (i *iptablesManager) CleanRoutingRules() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Info("done cleaning up iptables rules")
|
log.Info("done cleaning up iptables rules")
|
||||||
}
|
}
|
||||||
@@ -96,6 +122,7 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
|
|
||||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||||
|
|
||||||
|
if i.ipv4Client != nil {
|
||||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||||
@@ -106,7 +133,14 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
err = i.restoreRules(i.ipv4Client)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.ipv6Client != nil {
|
||||||
|
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||||
}
|
}
|
||||||
@@ -116,17 +150,13 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv4Client)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv6Client)
|
err = i.restoreRules(i.ipv6Client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = i.addJumpRules()
|
err := i.addJumpRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||||
}
|
}
|
||||||
@@ -140,12 +170,13 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if i.ipv4Client != nil {
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||||
|
|
||||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ipv4][ipv4Forwarding] = rule
|
i.rules[ipv4][ipv4Forwarding] = rule
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||||
@@ -154,8 +185,10 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
i.rules[ipv4][ipv4Nat] = rule
|
i.rules[ipv4][ipv4Nat] = rule
|
||||||
|
}
|
||||||
|
|
||||||
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
if i.ipv6Client != nil {
|
||||||
|
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -168,6 +201,7 @@ func (i *iptablesManager) addJumpRules() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
i.rules[ipv6][ipv6Nat] = rule
|
i.rules[ipv6][ipv6Nat] = rule
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -177,6 +211,7 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
var err error
|
var err error
|
||||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||||
|
if i.ipv4Client != nil {
|
||||||
if found {
|
if found {
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||||
@@ -192,6 +227,8 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if i.ipv6Client == nil {
|
||||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||||
if found {
|
if found {
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||||
@@ -208,6 +245,7 @@ func (i *iptablesManager) cleanJumpRules() error {
|
|||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -437,3 +475,8 @@ func getIptablesRuleType(table string) string {
|
|||||||
}
|
}
|
||||||
return ruleType
|
return ruleType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager := newIptablesManager(context.TODO())
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
|
|
||||||
manager := &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||||
|
|
||||||
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||||
|
|
||||||
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
pair = routerPair{
|
pair = routerPair{
|
||||||
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||||
|
|
||||||
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
delete(manager.rules, ipv4)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ const (
|
|||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||||
|
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to create nftable rules
|
// constants needed to create nftable rules
|
||||||
@@ -78,9 +81,36 @@ type nftablesManager struct {
|
|||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
chains map[string]map[string]*nftables.Chain
|
chains map[string]map[string]*nftables.Chain
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
|
filterTable *nftables.Table
|
||||||
|
defaultForwardRules []*nftables.Rule
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
mgr := &nftablesManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mgr.isSupported()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mgr.readFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing nftables rules from the system
|
// CleanRoutingRules cleans existing nftables rules from the system
|
||||||
func (n *nftablesManager) CleanRoutingRules() {
|
func (n *nftablesManager) CleanRoutingRules() {
|
||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
@@ -90,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.conn.FlushTable(n.tableIPv6)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) readFilterTable() error {
|
||||||
|
tables, err := n.conn.ListTables()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == "filter" {
|
||||||
|
n.filterTable = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||||
|
if n.defaultForwardRules[0] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := n.refreshDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
err = n.conn.DelRule(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
||||||
|
}
|
||||||
|
n.defaultForwardRules[i] = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
||||||
|
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if string(rule.UserData) == string(r.UserData) {
|
||||||
|
n.defaultForwardRules[i] = rule
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("unable to find forward accept rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
||||||
|
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
||||||
|
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r := &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
||||||
|
|
||||||
|
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
||||||
|
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
||||||
|
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r = &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||||
@@ -275,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
||||||
|
err = n.acceptForwardRule(pair.source)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to create default forward rule: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("default accept forward rule added")
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
@@ -355,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delte default fwd rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
@@ -386,6 +544,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) isSupported() error {
|
||||||
|
_, err := n.conn.ListChains()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables is not supported: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -14,21 +14,16 @@ import (
|
|||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||||
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
for _, testCase := range insertRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||||
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
for _, testCase := range removeRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
manager, err := newNFTablesManager(context.TODO())
|
||||||
|
if err != nil {
|
||||||
manager := &nftablesManager{
|
t.Fatalf("failed to create nftables manager: %s", err)
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
table := manager.tableIPv4
|
table := manager.tableIPv4
|
||||||
|
|||||||
82
client/internal/routemanager/systemops_bsd.go
Normal file
82
client/internal/routemanager/systemops_bsd.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
// +build darwin dragonfly freebsd netbsd openbsd
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// selected BSD Route flags.
|
||||||
|
const (
|
||||||
|
RTF_UP = 0x1
|
||||||
|
RTF_GATEWAY = 0x2
|
||||||
|
RTF_HOST = 0x4
|
||||||
|
RTF_REJECT = 0x8
|
||||||
|
RTF_DYNAMIC = 0x10
|
||||||
|
RTF_MODIFIED = 0x20
|
||||||
|
RTF_STATIC = 0x800
|
||||||
|
RTF_BLACKHOLE = 0x1000
|
||||||
|
RTF_LOCAL = 0x200000
|
||||||
|
RTF_BROADCAST = 0x400000
|
||||||
|
RTF_MULTICAST = 0x800000
|
||||||
|
)
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
m := msg.(*route.RouteMessage)
|
||||||
|
|
||||||
|
if m.Version < 3 || m.Version > 5 {
|
||||||
|
return false, fmt.Errorf("unexpected RIB message version: %d", m.Version)
|
||||||
|
}
|
||||||
|
if m.Type != 4 /* RTM_GET */ {
|
||||||
|
return true, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Flags&RTF_UP == 0 ||
|
||||||
|
m.Flags&(RTF_REJECT|RTF_BLACKHOLE) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst, err := toIPAddr(m.Addrs[0])
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("unexpected RIB destination: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mask, _ := toIPAddr(m.Addrs[2])
|
||||||
|
cidr, _ := net.IPMask(mask.To4()).Size()
|
||||||
|
if dst.String() == prefix.Addr().String() && cidr == prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toIPAddr(a route.Addr) (net.IP, error) {
|
||||||
|
switch t := a.(type) {
|
||||||
|
case *route.Inet4Addr:
|
||||||
|
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||||
|
return ip, nil
|
||||||
|
case *route.Inet6Addr:
|
||||||
|
ip := make(net.IP, net.IPv6len)
|
||||||
|
copy(ip, t.IP[:])
|
||||||
|
return ip, nil
|
||||||
|
default:
|
||||||
|
return net.IP{}, fmt.Errorf("unknown family: %v", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,10 +6,28 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
|
||||||
|
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
|
||||||
|
type routeInfoInMemory struct {
|
||||||
|
Family byte
|
||||||
|
DstLen byte
|
||||||
|
SrcLen byte
|
||||||
|
TOS byte
|
||||||
|
|
||||||
|
Table byte
|
||||||
|
Protocol byte
|
||||||
|
Scope byte
|
||||||
|
Type byte
|
||||||
|
|
||||||
|
Flags uint32
|
||||||
|
}
|
||||||
|
|
||||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
@@ -61,6 +79,45 @@ func removeFromRouteTable(prefix netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
msgs, err := syscall.ParseNetlinkMessage(tab)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
loop:
|
||||||
|
for _, m := range msgs {
|
||||||
|
switch m.Header.Type {
|
||||||
|
case syscall.NLMSG_DONE:
|
||||||
|
break loop
|
||||||
|
case syscall.RTM_NEWROUTE:
|
||||||
|
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
|
||||||
|
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
if rt.Family != syscall.AF_INET {
|
||||||
|
continue loop
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attr := range attrs {
|
||||||
|
if attr.Attr.Type == syscall.RTA_DST {
|
||||||
|
ip := net.IP(attr.Value)
|
||||||
|
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
|
||||||
|
cidr, _ := mask.Size()
|
||||||
|
if ip.String() == prefix.Addr().String() && cidr == prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,19 +14,26 @@ import (
|
|||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
var errRouteNotFound = fmt.Errorf("route not found")
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||||
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
if err != nil && err != errRouteNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
|
||||||
if err != nil && err != errRouteNotFound {
|
if err != nil && err != errRouteNotFound {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefixGateway != nil && !prefixGateway.Equal(gateway) {
|
gatewayIP := netip.MustParseAddr(defaultGateway.String())
|
||||||
log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway)
|
if prefix.Contains(gatewayIP) {
|
||||||
|
log.Warnf("skipping adding a new route for network %s because it overlaps with the default gateway: %s", prefix, gatewayIP)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ok, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return addToRouteTable(prefix, addr)
|
return addToRouteTable(prefix, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +60,7 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
|||||||
log.Errorf("getting routes returned an error: %v", err)
|
log.Errorf("getting routes returned an error: %v", err)
|
||||||
return nil, errRouteNotFound
|
return nil, errRouteNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if gateway == nil {
|
if gateway == nil {
|
||||||
return preferredSrc, nil
|
return preferredSrc, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/pion/transport/v2/stdnet"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddRemoveRoutes(t *testing.T) {
|
||||||
@@ -114,3 +120,98 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
|
|||||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||||
|
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
|
fmt.Println("defaultGateway: ", defaultGateway)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
|
}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
prefix netip.Prefix
|
||||||
|
preExistingPrefix netip.Prefix
|
||||||
|
shouldAddRoute bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Add And Remove random Route",
|
||||||
|
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
||||||
|
shouldAddRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Not Add Route if overlaps with default gateway",
|
||||||
|
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||||
|
shouldAddRoute: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Add Route if bigger network exists",
|
||||||
|
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||||
|
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||||
|
shouldAddRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Add Route if smaller network exists",
|
||||||
|
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||||
|
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||||
|
shouldAddRoute: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Not Add Route if same network exists",
|
||||||
|
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||||
|
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||||
|
shouldAddRoute: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, testCase := range testCases {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
log.SetOutput(&buf)
|
||||||
|
defer func() {
|
||||||
|
log.SetOutput(os.Stderr)
|
||||||
|
}()
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
|
defer wgInterface.Close()
|
||||||
|
|
||||||
|
err = wgInterface.Create()
|
||||||
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
|
MockAddr := wgInterface.Address().IP.String()
|
||||||
|
|
||||||
|
// Prepare the environment
|
||||||
|
if testCase.preExistingPrefix.IsValid() {
|
||||||
|
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
|
||||||
|
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the route
|
||||||
|
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
|
||||||
|
require.NoError(t, err, "should not return err when adding route")
|
||||||
|
|
||||||
|
if testCase.shouldAddRoute {
|
||||||
|
// test if route exists after adding
|
||||||
|
ok, err := existsInRouteTable(testCase.prefix)
|
||||||
|
require.NoError(t, err, "should not return err")
|
||||||
|
require.True(t, ok, "route should exist")
|
||||||
|
|
||||||
|
// remove route again if added
|
||||||
|
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
|
||||||
|
require.NoError(t, err, "should not return err")
|
||||||
|
}
|
||||||
|
|
||||||
|
// route should either not have been added or should have been removed
|
||||||
|
// In case of already existing route, it should not have been added (but still exist)
|
||||||
|
ok, err := existsInRouteTable(testCase.prefix)
|
||||||
|
fmt.Println("Buffer string: ", buf.String())
|
||||||
|
require.NoError(t, err, "should not return err")
|
||||||
|
if !strings.Contains(buf.String(), "because it already exists") {
|
||||||
|
require.False(t, ok, "route should not exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
37
client/internal/routemanager/systemops_windows.go
Normal file
37
client/internal/routemanager/systemops_windows.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Win32_IP4RouteTable struct {
|
||||||
|
Destination string
|
||||||
|
Mask string
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
var routes []Win32_IP4RouteTable
|
||||||
|
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||||
|
|
||||||
|
err := wmi.Query(query, &routes)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
ip := net.ParseIP(route.Mask)
|
||||||
|
ip = ip.To4()
|
||||||
|
mask := net.IPv4Mask(ip[0], ip[1], ip[2], ip[3])
|
||||||
|
cidr, _ := mask.Size()
|
||||||
|
if route.Destination == prefix.Addr().String() && cidr == prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
@@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
|
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
||||||
log.Errorf("init connections: %v", err)
|
log.Errorf("init connections: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
|
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
||||||
log.Errorf("run client connection: %v", err)
|
log.Errorf("run client connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,6 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/creack/pty"
|
|
||||||
"github.com/gliderlabs/ssh"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,11 +10,22 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/creack/pty"
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||||
const DefaultSSHPort = 44338
|
const DefaultSSHPort = 44338
|
||||||
|
|
||||||
|
// TerminalTimeout is the timeout for terminal session to be ready
|
||||||
|
const TerminalTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||||
|
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
// DefaultSSHServer is a function that creates DefaultServer
|
// DefaultSSHServer is a function that creates DefaultServer
|
||||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||||
return newDefaultServer(hostKeyPEM, addr)
|
return newDefaultServer(hostKeyPEM, addr)
|
||||||
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||||
|
|
||||||
localUser, err := userNameLookup(session.User())
|
localUser, err := userNameLookup(session.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||||
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Login command: %s", cmd.String())
|
||||||
file, err := pty.Start(cmd)
|
file, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed starting SSH server %v", err)
|
log.Errorf("failed starting SSH server %v", err)
|
||||||
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Debugf("SSH session ended")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||||
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
|||||||
// stdin
|
// stdin
|
||||||
_, err := io.Copy(file, session)
|
_, err := io.Copy(file, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = session.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||||
|
timer := time.NewTimer(TerminalTimeout)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||||
|
_ = session.Exit(1)
|
||||||
|
return
|
||||||
|
default:
|
||||||
// stdout
|
// stdout
|
||||||
_, err := io.Copy(session, file)
|
writtenBytes, err := io.Copy(session, file)
|
||||||
if err != nil {
|
if err != nil && writtenBytes != 0 {
|
||||||
|
_ = session.Exit(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
time.Sleep(TerminalBackoffDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts SSH server. Blocking
|
// Start starts SSH server. Blocking
|
||||||
|
|||||||
23
go.mod
23
go.mod
@@ -17,12 +17,12 @@ require (
|
|||||||
github.com/spf13/cobra v1.6.1
|
github.com/spf13/cobra v1.6.1
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/vishvananda/netlink v1.1.0
|
github.com/vishvananda/netlink v1.1.0
|
||||||
golang.org/x/crypto v0.7.0
|
golang.org/x/crypto v0.9.0
|
||||||
golang.org/x/sys v0.8.0
|
golang.org/x/sys v0.8.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
|
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.52.3
|
google.golang.org/grpc v1.55.0
|
||||||
google.golang.org/protobuf v1.30.0
|
google.golang.org/protobuf v1.30.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||||
)
|
)
|
||||||
@@ -48,6 +48,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.4.0
|
github.com/mdlayher/socket v0.4.0
|
||||||
github.com/miekg/dns v1.1.43
|
github.com/miekg/dns v1.1.43
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/pion/logging v0.2.2
|
github.com/pion/logging v0.2.2
|
||||||
@@ -57,6 +58,7 @@ require (
|
|||||||
github.com/rs/xid v1.3.0
|
github.com/rs/xid v1.3.0
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/stretchr/testify v1.8.1
|
github.com/stretchr/testify v1.8.1
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.3
|
||||||
go.opentelemetry.io/otel v1.11.1
|
go.opentelemetry.io/otel v1.11.1
|
||||||
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
|
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
|
||||||
go.opentelemetry.io/otel/metric v0.33.0
|
go.opentelemetry.io/otel/metric v0.33.0
|
||||||
@@ -65,18 +67,22 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
|
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
|
||||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
||||||
golang.org/x/net v0.10.0
|
golang.org/x/net v0.10.0
|
||||||
golang.org/x/sync v0.1.0
|
golang.org/x/oauth2 v0.8.0
|
||||||
|
golang.org/x/sync v0.2.0
|
||||||
golang.org/x/term v0.8.0
|
golang.org/x/term v0.8.0
|
||||||
|
google.golang.org/api v0.126.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
cloud.google.com/go/compute v1.19.3 // indirect
|
||||||
|
cloud.google.com/go/compute/metadata v0.2.3 // indirect
|
||||||
github.com/BurntSushi/toml v1.2.1 // indirect
|
github.com/BurntSushi/toml v1.2.1 // indirect
|
||||||
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
|
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
|
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
@@ -92,9 +98,14 @@ require (
|
|||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
|
||||||
github.com/go-logr/logr v1.2.3 // indirect
|
github.com/go-logr/logr v1.2.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||||
github.com/go-stack/stack v1.8.0 // indirect
|
github.com/go-stack/stack v1.8.0 // indirect
|
||||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
|
github.com/google/s2a-go v0.1.4 // indirect
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
||||||
|
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
|
||||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/josharian/native v1.0.0 // indirect
|
github.com/josharian/native v1.0.0 // indirect
|
||||||
@@ -120,17 +131,17 @@ require (
|
|||||||
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
||||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
||||||
github.com/yuin/goldmark v1.4.13 // indirect
|
github.com/yuin/goldmark v1.4.13 // indirect
|
||||||
|
go.opencensus.io v0.24.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
|
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
||||||
golang.org/x/image v0.5.0 // indirect
|
golang.org/x/image v0.5.0 // indirect
|
||||||
golang.org/x/mod v0.8.0 // indirect
|
golang.org/x/mod v0.8.0 // indirect
|
||||||
golang.org/x/oauth2 v0.8.0 // indirect
|
|
||||||
golang.org/x/text v0.9.0 // indirect
|
golang.org/x/text v0.9.0 // indirect
|
||||||
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
|
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
|
||||||
golang.org/x/tools v0.6.0 // indirect
|
golang.org/x/tools v0.6.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/appengine v1.6.7 // indirect
|
google.golang.org/appengine v1.6.7 // indirect
|
||||||
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
|
|||||||
48
go.sum
48
go.sum
@@ -20,7 +20,11 @@ cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvf
|
|||||||
cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
|
cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
|
||||||
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
|
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
|
||||||
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
|
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
|
||||||
|
cloud.google.com/go/compute v1.19.3 h1:DcTwsFgGev/wV5+q8o2fzgcHOaac+DKGC91ZlvpsQds=
|
||||||
|
cloud.google.com/go/compute v1.19.3/go.mod h1:qxvISKp/gYnXkSAD1ppcSOveRAmzxicEv/JlizULFrI=
|
||||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
|
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
|
||||||
|
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
|
||||||
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
|
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
|
||||||
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
|
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
|
||||||
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
|
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
|
||||||
@@ -87,8 +91,9 @@ github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq
|
|||||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||||
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
|
||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||||
|
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||||
@@ -101,6 +106,7 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
|
|||||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||||
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
|
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
|
||||||
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
|
github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
@@ -163,6 +169,7 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m
|
|||||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
||||||
|
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
|
||||||
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
|
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
|
||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
|
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
|
||||||
@@ -219,6 +226,7 @@ github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
|||||||
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
|
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
|
||||||
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||||
@@ -250,12 +258,14 @@ github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff/go.mod h1:wfqRWLHRBs
|
|||||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ=
|
github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE=
|
||||||
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
|
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
|
||||||
@@ -321,13 +331,19 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf
|
|||||||
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||||
|
github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc=
|
||||||
|
github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A=
|
||||||
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k=
|
||||||
|
github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||||
|
github.com/googleapis/gax-go/v2 v2.10.0 h1:ebSgKfMxynOdxw8QQuFOKMgomqeLGPqNLQox2bo42zg=
|
||||||
|
github.com/googleapis/gax-go/v2 v2.10.0/go.mod h1:4UOEnMCrxsSqQ940WnTiD6qJ63le2ev3xfyagutxiPw=
|
||||||
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
|
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
|
||||||
github.com/googleapis/gnostic v0.4.0/go.mod h1:on+2t9HRStVgn95RSsFWFz+6Q0Snyqv1awfrALZdbtU=
|
github.com/googleapis/gnostic v0.4.0/go.mod h1:on+2t9HRStVgn95RSsFWFz+6Q0Snyqv1awfrALZdbtU=
|
||||||
github.com/googleapis/gnostic v0.5.1/go.mod h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU=
|
github.com/googleapis/gnostic v0.5.1/go.mod h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU=
|
||||||
@@ -469,6 +485,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
|
|||||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||||
|
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||||
|
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
|
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
|
||||||
@@ -661,6 +679,8 @@ github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
|
|||||||
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||||
@@ -668,6 +688,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
|||||||
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||||
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||||
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
|
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
|
||||||
|
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||||
|
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||||
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
|
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
|
||||||
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
|
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
|
||||||
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
|
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 h1:xXhPj7SLKWU5/Zd4Hxmd+X1C4jdmvc0Xy+kvjFx2z60=
|
||||||
@@ -697,10 +719,11 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
|||||||
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
|
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
|
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
|
||||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||||
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
|
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@@ -831,8 +854,9 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||||
|
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -949,6 +973,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
|
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
@@ -1041,6 +1066,8 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M
|
|||||||
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
|
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
|
||||||
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
|
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
|
||||||
google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
|
google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
|
||||||
|
google.golang.org/api v0.126.0 h1:q4GJq+cAdMAC7XP7njvQ4tvohGLiSlytuL4BQxbIZ+o=
|
||||||
|
google.golang.org/api v0.126.0/go.mod h1:mBwVAtz+87bEN6CbA1GtZPDOqY2R5ONPqJeIlvyo4Aw=
|
||||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
@@ -1082,8 +1109,10 @@ google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6D
|
|||||||
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||||
google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||||
google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48=
|
google.golang.org/genproto v0.0.0-20210722135532-667f2b7c528f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48=
|
||||||
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 h1:a2S6M0+660BgMNl++4JPlcAO/CjkqYItDEZwkoDQK7c=
|
google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao=
|
||||||
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
|
google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
|
||||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||||
google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||||
@@ -1102,9 +1131,10 @@ google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTp
|
|||||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||||
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||||
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
|
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
|
||||||
|
google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ=
|
||||||
google.golang.org/grpc v1.51.0-dev/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
|
google.golang.org/grpc v1.51.0-dev/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
|
||||||
google.golang.org/grpc v1.52.3 h1:pf7sOysg4LdgBqduXveGKrcEwbStiK2rtfghdzlUYDQ=
|
google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag=
|
||||||
google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5vorUY=
|
google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8=
|
||||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
|
|||||||
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
|
||||||
return w.tun.Create(mIFaceArgs)
|
return w.tun.Create(mIFaceArgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
|||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
|
||||||
return w.tun.Create()
|
return w.tun.Create()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
||||||
|
log.Info("create tun interface")
|
||||||
var err error
|
var err error
|
||||||
routesString := t.routesToString(mIFaceArgs.Routes)
|
routesString := t.routesToString(mIFaceArgs.Routes)
|
||||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ import (
|
|||||||
|
|
||||||
func (c *tunDevice) Create() error {
|
func (c *tunDevice) Create() error {
|
||||||
if WireGuardModuleIsLoaded() {
|
if WireGuardModuleIsLoaded() {
|
||||||
log.Info("using kernel WireGuard")
|
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
|
||||||
return c.createWithKernel()
|
return c.createWithKernel()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tunModuleIsLoaded() {
|
if !tunModuleIsLoaded() {
|
||||||
return fmt.Errorf("couldn't check or load tun module")
|
return fmt.Errorf("couldn't check or load tun module")
|
||||||
}
|
}
|
||||||
log.Info("using userspace WireGuard")
|
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
|
||||||
var err error
|
var err error
|
||||||
c.netInterface, err = c.createWithUserspace()
|
c.netInterface, err = c.createWithUserspace()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
|||||||
c.wrapper = newDeviceWrapper(tunIface)
|
c.wrapper = newDeviceWrapper(tunIface)
|
||||||
|
|
||||||
// We need to create a wireguard-go device and listen to configuration requests
|
// We need to create a wireguard-go device and listen to configuration requests
|
||||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||||
err = tunDev.Up()
|
err = tunDev.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
|
|||||||
@@ -97,16 +97,9 @@ curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o openid-configuration.js
|
|||||||
|
|
||||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
|
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
|
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES=$(jq -r '.scopes_supported | join(" ")' openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
|
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
|
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
|
||||||
|
|
||||||
if [ "$NETBIRD_USE_AUTH0" == "true" ]; then
|
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api email_verified"
|
|
||||||
else
|
|
||||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||||
# user enabled Device Authorization Grant feature
|
# user enabled Device Authorization Grant feature
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||||
|
|||||||
@@ -11,7 +11,10 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
|||||||
NETBIRD_AUTH_AUDIENCE=""
|
NETBIRD_AUTH_AUDIENCE=""
|
||||||
# e.g. netbird-client
|
# e.g. netbird-client
|
||||||
NETBIRD_AUTH_CLIENT_ID=""
|
NETBIRD_AUTH_CLIENT_ID=""
|
||||||
NETBIRD_AUTH_CLIENT_SECRET=""
|
# indicates the scopes that will be requested to the IDP
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
||||||
|
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
||||||
|
# NETBIRD_AUTH_CLIENT_SECRET=""
|
||||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||||
# indicates whether to use Auth0 or not: true or false
|
# indicates whether to use Auth0 or not: true or false
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ NETBIRD_DOMAIN=$CI_NETBIRD_DOMAIN
|
|||||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
|
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
|
||||||
# e.g. netbird-client
|
# e.g. netbird-client
|
||||||
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
|
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=$CI_NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||||
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
|
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||||
# indicates whether to use Auth0 or not: true or false
|
# indicates whether to use Auth0 or not: true or false
|
||||||
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
||||||
|
|||||||
@@ -218,7 +218,11 @@ var (
|
|||||||
if !disableMetrics {
|
if !disableMetrics {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager)
|
idpManager := "disabled"
|
||||||
|
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
|
||||||
|
idpManager = config.IdpManagerConfig.ManagerType
|
||||||
|
}
|
||||||
|
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
|
||||||
go metricsWorker.Run()
|
go metricsWorker.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ const (
|
|||||||
PublicCategory = "public"
|
PublicCategory = "public"
|
||||||
PrivateCategory = "private"
|
PrivateCategory = "private"
|
||||||
UnknownCategory = "unknown"
|
UnknownCategory = "unknown"
|
||||||
|
GroupIssuedAPI = "api"
|
||||||
|
GroupIssuedJWT = "jwt"
|
||||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
@@ -51,6 +53,7 @@ type AccountManager interface {
|
|||||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||||
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||||
|
InviteUser(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||||
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
@@ -78,7 +81,7 @@ type AccountManager interface {
|
|||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
SaveGroup(accountID, userID string, group *Group) error
|
SaveGroup(accountID, userID string, group *Group) error
|
||||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||||
DeleteGroup(accountId, groupID string) error
|
DeleteGroup(accountId, userId, groupID string) error
|
||||||
ListGroups(accountId string) ([]*Group, error)
|
ListGroups(accountId string) ([]*Group, error)
|
||||||
GroupAddPeer(accountId, groupID, peerID string) error
|
GroupAddPeer(accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(accountId, groupID, peerKey string) error
|
GroupDeletePeer(accountId, groupID, peerKey string) error
|
||||||
@@ -139,6 +142,13 @@ type Settings struct {
|
|||||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||||
PeerLoginExpiration time.Duration
|
PeerLoginExpiration time.Duration
|
||||||
|
|
||||||
|
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||||
|
// and add it to account groups.
|
||||||
|
JWTGroupsEnabled bool
|
||||||
|
|
||||||
|
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||||
|
JWTGroupsClaimName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy copies the Settings struct
|
// Copy copies the Settings struct
|
||||||
@@ -146,6 +156,8 @@ func (s *Settings) Copy() *Settings {
|
|||||||
return &Settings{
|
return &Settings{
|
||||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||||
|
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||||
|
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -612,6 +624,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
|||||||
return a.Peers[peerID]
|
return a.Peers[peerID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddJWTGroups to existed groups if they does not exists
|
||||||
|
func (a *Account) AddJWTGroups(groups []string) (int, error) {
|
||||||
|
existedGroups := make(map[string]*Group)
|
||||||
|
for _, g := range a.Groups {
|
||||||
|
existedGroups[g.Name] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
for _, name := range groups {
|
||||||
|
if _, ok := existedGroups[name]; !ok {
|
||||||
|
id := xid.New().String()
|
||||||
|
a.Groups[id] = &Group{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
Issued: GroupIssuedJWT,
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||||
@@ -1241,6 +1275,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Settings.JWTGroupsEnabled {
|
||||||
|
if account.Settings.JWTGroupsClaimName == "" {
|
||||||
|
log.Errorf("JWT groups are enabled but no claim name is set")
|
||||||
|
return account, user, nil
|
||||||
|
}
|
||||||
|
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||||
|
if slice, ok := claim.([]interface{}); ok {
|
||||||
|
var groups []string
|
||||||
|
for _, item := range slice {
|
||||||
|
if g, ok := item.(string); ok {
|
||||||
|
groups = append(groups, g)
|
||||||
|
} else {
|
||||||
|
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err := account.AddJWTGroups(groups)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add JWT groups: %v", err)
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
if err := am.Store.SaveAccount(account); err != nil {
|
||||||
|
log.Errorf("failed to save account: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return account, user, nil
|
return account, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1346,6 +1412,7 @@ func addAllGroup(account *Account) error {
|
|||||||
allGroup := &Group{
|
allGroup := &Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: "All",
|
Name: "All",
|
||||||
|
Issued: GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||||
@@ -1373,33 +1440,28 @@ func addAllGroup(account *Account) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
func newAccountWithId(accountID, userID, domain string) *Account {
|
||||||
log.Debugf("creating new account")
|
log.Debugf("creating new account")
|
||||||
|
|
||||||
setupKeys := make(map[string]*SetupKey)
|
|
||||||
defaultKey := GenerateDefaultSetupKey()
|
|
||||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{},
|
|
||||||
SetupKeyUnlimitedUsage)
|
|
||||||
setupKeys[defaultKey.Key] = defaultKey
|
|
||||||
setupKeys[oneOffKey.Key] = oneOffKey
|
|
||||||
network := NewNetwork()
|
network := NewNetwork()
|
||||||
peers := make(map[string]*Peer)
|
peers := make(map[string]*Peer)
|
||||||
users := make(map[string]*User)
|
users := make(map[string]*User)
|
||||||
routes := make(map[string]*route.Route)
|
routes := make(map[string]*route.Route)
|
||||||
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userId] = NewAdminUser(userId)
|
users[userID] = NewAdminUser(userID)
|
||||||
dnsSettings := &DNSSettings{
|
dnsSettings := &DNSSettings{
|
||||||
DisabledManagementGroups: make([]string, 0),
|
DisabledManagementGroups: make([]string, 0),
|
||||||
}
|
}
|
||||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
log.Debugf("created new account %s", accountID)
|
||||||
|
|
||||||
acc := &Account{
|
acc := &Account{
|
||||||
Id: accountId,
|
Id: accountID,
|
||||||
SetupKeys: setupKeys,
|
SetupKeys: setupKeys,
|
||||||
Network: network,
|
Network: network,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Users: users,
|
Users: users,
|
||||||
CreatedBy: userId,
|
CreatedBy: userID,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
NameServerGroups: nameServersGroups,
|
NameServerGroups: nameServersGroups,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -53,7 +54,7 @@ func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy
|
|||||||
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(account.SetupKeys) != 2 {
|
if len(account.SetupKeys) != 0 {
|
||||||
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
|
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -460,6 +461,75 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||||
|
userId := "user-id"
|
||||||
|
domain := "test.domain"
|
||||||
|
|
||||||
|
initAccount := newAccountWithId("", userId, domain)
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID := initAccount.Id
|
||||||
|
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||||
|
require.NoError(t, err, "create init user failed")
|
||||||
|
// as initAccount was created without account id we have to take the id after account initialization
|
||||||
|
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||||
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
|
initAccount.Id = acc.Id
|
||||||
|
|
||||||
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||||
|
Domain: domain,
|
||||||
|
UserId: userId,
|
||||||
|
DomainCategory: "test-category",
|
||||||
|
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||||
|
initAccount.Settings.JWTGroupsEnabled = true
|
||||||
|
err := manager.Store.SaveAccount(initAccount)
|
||||||
|
require.NoError(t, err, "save account failed")
|
||||||
|
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||||
|
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||||
|
initAccount.Settings.JWTGroupsEnabled = true
|
||||||
|
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||||
|
err := manager.Store.SaveAccount(initAccount)
|
||||||
|
require.NoError(t, err, "save account failed")
|
||||||
|
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||||
|
|
||||||
|
account, _, err := manager.GetAccountFromToken(claims)
|
||||||
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||||
|
|
||||||
|
groupsByNames := map[string]*Group{}
|
||||||
|
for _, g := range account.Groups {
|
||||||
|
groupsByNames[g.Name] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
g1, ok := groupsByNames["group1"]
|
||||||
|
require.True(t, ok, "group1 should be added to the account")
|
||||||
|
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||||
|
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
|
||||||
|
|
||||||
|
g2, ok := groupsByNames["group2"]
|
||||||
|
require.True(t, ok, "group2 should be added to the account")
|
||||||
|
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||||
|
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId("account_id", "testuser", "")
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
@@ -704,20 +774,21 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := createAccount(manager, "test_account", "account_creator", "netbird.cloud")
|
userID := "account_creator"
|
||||||
|
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := account.Network.CurrentSerial() // should be 0
|
serial := account.Network.CurrentSerial() // should be 0
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if setupKey == nil {
|
if err != nil {
|
||||||
t.Errorf("expecting account to have a default setup key")
|
t.Fatal("error creating setup key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -858,16 +929,13 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
if setupKey.Type == SetupKeyReusable {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if setupKey == nil {
|
if err != nil {
|
||||||
t.Errorf("expecting account to have a default setup key")
|
t.Fatal("error creating setup key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1021,7 +1089,10 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.DeleteGroup(account.Id, group.ID); err != nil {
|
// clean policy is pre requirement for delete group
|
||||||
|
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
|
||||||
|
|
||||||
|
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
|
||||||
t.Errorf("delete group: %v", err)
|
t.Errorf("delete group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1042,9 +1113,14 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||||
for _, key := range account.SetupKeys {
|
if err != nil {
|
||||||
setupKey = key
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ const (
|
|||||||
UserBlocked
|
UserBlocked
|
||||||
// UserUnblocked indicates that a user unblocked another user
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
UserUnblocked
|
UserUnblocked
|
||||||
|
// GroupDeleted indicates that a user deleted group
|
||||||
|
GroupDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -192,6 +194,8 @@ const (
|
|||||||
UserBlockedMessage string = "User blocked"
|
UserBlockedMessage string = "User blocked"
|
||||||
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||||
UserUnblockedMessage string = "User unblocked"
|
UserUnblockedMessage string = "User unblocked"
|
||||||
|
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
|
||||||
|
GroupDeletedMessage string = "Group deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Activity that triggered an Event
|
// Activity that triggered an Event
|
||||||
@@ -294,6 +298,8 @@ func (a Activity) Message() string {
|
|||||||
return UserBlockedMessage
|
return UserBlockedMessage
|
||||||
case UserUnblocked:
|
case UserUnblocked:
|
||||||
return UserUnblockedMessage
|
return UserUnblockedMessage
|
||||||
|
case GroupDeleted:
|
||||||
|
return GroupDeletedMessage
|
||||||
default:
|
default:
|
||||||
return "UNKNOWN_ACTIVITY"
|
return "UNKNOWN_ACTIVITY"
|
||||||
}
|
}
|
||||||
@@ -342,6 +348,8 @@ func (a Activity) StringCode() string {
|
|||||||
return "group.add"
|
return "group.add"
|
||||||
case GroupUpdated:
|
case GroupUpdated:
|
||||||
return "group.update"
|
return "group.update"
|
||||||
|
case GroupDeleted:
|
||||||
|
return "group.delete"
|
||||||
case GroupRemovedFromPeer:
|
case GroupRemovedFromPeer:
|
||||||
return "peer.group.delete"
|
return "peer.group.delete"
|
||||||
case GroupAddedToPeer:
|
case GroupAddedToPeer:
|
||||||
|
|||||||
@@ -2,13 +2,15 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultTTL = 300
|
const defaultTTL = 300
|
||||||
@@ -199,15 +201,27 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
|||||||
for _, gID := range nsGroup.Groups {
|
for _, gID := range nsGroup.Groups {
|
||||||
_, found := groupList[gID]
|
_, found := groupList[gID]
|
||||||
if found {
|
if found {
|
||||||
|
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return peerNSGroups
|
return peerNSGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||||
|
func peerIsNameserver(peer *Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
@@ -17,6 +19,7 @@ const (
|
|||||||
dnsAccountID = "testingAcc"
|
dnsAccountID = "testingAcc"
|
||||||
dnsAdminUserID = "testingAdminUser"
|
dnsAdminUserID = "testingAdminUser"
|
||||||
dnsRegularUserID = "testingRegularUser"
|
dnsRegularUserID = "testingRegularUser"
|
||||||
|
dnsNSGroup1 = "ns1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetDNSSettings(t *testing.T) {
|
func TestGetDNSSettings(t *testing.T) {
|
||||||
@@ -163,6 +166,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||||
|
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||||
|
|
||||||
dnsSettings := account.DNSSettings.Copy()
|
dnsSettings := account.DNSSettings.Copy()
|
||||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||||
@@ -174,11 +178,11 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||||
|
|
||||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||||
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||||
@@ -246,7 +250,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = am.AddPeer("", dnsAdminUserID, peer1)
|
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -284,6 +288,24 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
|||||||
account.Groups[newGroup1.ID] = newGroup1
|
account.Groups[newGroup1.ID] = newGroup1
|
||||||
account.Groups[newGroup2.ID] = newGroup2
|
account.Groups[newGroup2.ID] = newGroup2
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||||
|
ID: dnsNSGroup1,
|
||||||
|
Name: "ns-group-1",
|
||||||
|
NameServers: []dns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||||
|
NSType: dns.UDPNameServerType,
|
||||||
|
Port: dns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Primary: true,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{allGroup.ID},
|
||||||
|
}
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
|
|||||||
addPeerLabelsToAccount(account, existingLabels)
|
addPeerLabelsToAccount(account, existingLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: delete this block after migration
|
||||||
|
// Set API as issuer for groups which has not this field
|
||||||
|
for _, group := range account.Groups {
|
||||||
|
if group.Issued == "" {
|
||||||
|
group.Issued = GroupIssuedAPI
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||||
@@ -281,6 +289,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
if account.Id == "" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "account id should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
accountCopy := account.Copy()
|
accountCopy := account.Copy()
|
||||||
|
|
||||||
s.Accounts[accountCopy.Id] = accountCopy
|
s.Accounts[accountCopy.Id] = accountCopy
|
||||||
@@ -326,7 +338,7 @@ func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
|||||||
|
|
||||||
delete(s.HashedPAT2TokenID, hashedToken)
|
delete(s.HashedPAT2TokenID, hashedToken)
|
||||||
|
|
||||||
return s.persist(s.storeFile)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
||||||
@@ -336,7 +348,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|||||||
|
|
||||||
delete(s.TokenID2UserID, tokenID)
|
delete(s.TokenID2UserID, tokenID)
|
||||||
|
|
||||||
return s.persist(s.storeFile)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPrivateDomain returns account by private domain
|
// GetAccountByPrivateDomain returns account by private domain
|
||||||
|
|||||||
@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
|
|||||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: outdated, delete this
|
||||||
func TestRestorePolicies_Migration(t *testing.T) {
|
func TestRestorePolicies_Migration(t *testing.T) {
|
||||||
storeDir := t.TempDir()
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
|
|||||||
"failed to restore a FileStore file - missing Account Policies Sources")
|
"failed to restore a FileStore file - missing Account Policies Sources")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRestoreGroups_Migration(t *testing.T) {
|
||||||
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
|
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := NewFileStore(storeDir, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create default group
|
||||||
|
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||||
|
account.Groups = map[string]*Group{
|
||||||
|
"cfefqs706sqkneg59g3g": {
|
||||||
|
ID: "cfefqs706sqkneg59g3g",
|
||||||
|
Name: "All",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SaveAccount(account)
|
||||||
|
require.NoError(t, err, "failed to save account")
|
||||||
|
|
||||||
|
// restore account with default group with empty Issue field
|
||||||
|
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||||
|
|
||||||
|
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
|
||||||
|
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||||
storeDir := t.TempDir()
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,23 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type GroupLinkError struct {
|
||||||
|
Resource string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *GroupLinkError) Error() string {
|
||||||
|
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||||
|
}
|
||||||
|
|
||||||
// Group of the peers for ACL
|
// Group of the peers for ACL
|
||||||
type Group struct {
|
type Group struct {
|
||||||
// ID of the group
|
// ID of the group
|
||||||
@@ -14,6 +26,9 @@ type Group struct {
|
|||||||
// Name visible in the UI
|
// Name visible in the UI
|
||||||
Name string
|
Name string
|
||||||
|
|
||||||
|
// Issued of the group
|
||||||
|
Issued string
|
||||||
|
|
||||||
// Peers list of the group
|
// Peers list of the group
|
||||||
Peers []string
|
Peers []string
|
||||||
}
|
}
|
||||||
@@ -47,6 +62,7 @@ func (g *Group) Copy() *Group {
|
|||||||
return &Group{
|
return &Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
|
Issued: g.Issued,
|
||||||
Peers: g.Peers[:],
|
Peers: g.Peers[:],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,15 +215,80 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup object of the peers
|
// DeleteGroup object of the peers
|
||||||
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountId)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
g, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check route links
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
for _, g := range r.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"route", r.NetID}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DNS links
|
||||||
|
for _, dns := range account.NameServerGroups {
|
||||||
|
for _, g := range dns.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"name server groups", dns.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check ACL links
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
for _, src := range rule.Sources {
|
||||||
|
if src == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dst := range rule.Destinations {
|
||||||
|
if dst == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check setup key links
|
||||||
|
for _, setupKey := range account.SetupKeys {
|
||||||
|
for _, grp := range setupKey.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"setup key", setupKey.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check user links
|
||||||
|
for _, user := range account.Users {
|
||||||
|
for _, grp := range user.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"user", user.Id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DisabledManagementGroups
|
||||||
|
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||||
|
if disabledMgmGrp == groupID {
|
||||||
|
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(account.Groups, groupID)
|
delete(account.Groups, groupID)
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
@@ -215,6 +296,8 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
return am.updateAccountPeers(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
164
management/server/group_test.go
Normal file
164
management/server/group_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
groupAdminUserID = "testingAdminUser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||||
|
am, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create account manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := initTestGroupAccount(am)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to init testing account")
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
groupID string
|
||||||
|
expectedReason string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"route",
|
||||||
|
"grp-for-route",
|
||||||
|
"route",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name server groups",
|
||||||
|
"grp-for-name-server-grp",
|
||||||
|
"name server groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"policy",
|
||||||
|
"grp-for-policies",
|
||||||
|
"policy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup keys",
|
||||||
|
"grp-for-keys",
|
||||||
|
"setup key",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"users",
|
||||||
|
"grp-for-users",
|
||||||
|
"user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
err = am.DeleteGroup(account.Id, "", testCase.groupID)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("delete %s group successfully", testCase.groupID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gErr, ok := err.(*GroupLinkError)
|
||||||
|
if !ok {
|
||||||
|
t.Error("invalid error type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gErr.Resource != testCase.expectedReason {
|
||||||
|
t.Errorf("invalid error case: %s, expected: %s", gErr.Resource, testCase.expectedReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||||
|
accountID := "testingAcc"
|
||||||
|
domain := "example.com"
|
||||||
|
|
||||||
|
groupForRoute := &Group{
|
||||||
|
"grp-for-route",
|
||||||
|
"Group for route",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForNameServerGroups := &Group{
|
||||||
|
"grp-for-name-server-grp",
|
||||||
|
"Group for name server groups",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForPolicies := &Group{
|
||||||
|
"grp-for-policies",
|
||||||
|
"Group for policies",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForSetupKeys := &Group{
|
||||||
|
"grp-for-keys",
|
||||||
|
"Group for setup keys",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupForUsers := &Group{
|
||||||
|
"grp-for-users",
|
||||||
|
"Group for users",
|
||||||
|
GroupIssuedAPI,
|
||||||
|
make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
routeResource := &route.Route{
|
||||||
|
ID: "example route",
|
||||||
|
Groups: []string{groupForRoute.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
nameServerGroup := &nbdns.NameServerGroup{
|
||||||
|
ID: "example name server group",
|
||||||
|
Groups: []string{groupForNameServerGroups.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := &Policy{
|
||||||
|
ID: "example policy",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "example policy rule",
|
||||||
|
Destinations: []string{groupForPolicies.ID},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey := &SetupKey{
|
||||||
|
Id: "example setup key",
|
||||||
|
AutoGroups: []string{groupForSetupKeys.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &User{
|
||||||
|
Id: "example user",
|
||||||
|
AutoGroups: []string{groupForUsers.ID},
|
||||||
|
}
|
||||||
|
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||||
|
account.Routes[routeResource.ID] = routeResource
|
||||||
|
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||||
|
account.Policies = append(account.Policies, policy)
|
||||||
|
account.SetupKeys[setupKey.Id] = setupKey
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
|
||||||
|
err := am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||||
|
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
|
||||||
|
|
||||||
|
return am.Store.GetAccount(account.Id)
|
||||||
|
}
|
||||||
@@ -72,10 +72,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, &server.Settings{
|
settings := &server.Settings{
|
||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if req.Settings.JwtGroupsEnabled != nil {
|
||||||
|
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.JwtGroupsClaimName != nil {
|
||||||
|
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
@@ -93,6 +102,8 @@ func toAccountResponse(account *server.Account) *api.Account {
|
|||||||
Settings: api.AccountSettings{
|
Settings: api.AccountSettings{
|
||||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
||||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
||||||
|
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
||||||
|
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
|
|
||||||
|
sr := func(v string) *string { return &v }
|
||||||
|
br := func(v bool) *bool { return &v }
|
||||||
|
|
||||||
handler := initAccountsTestData(&server.Account{
|
handler := initAccountsTestData(&server.Account{
|
||||||
Id: accountID,
|
Id: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@@ -91,6 +94,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: int(time.Hour.Seconds()),
|
PeerLoginExpiration: int(time.Hour.Seconds()),
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
|
JwtGroupsClaimName: sr(""),
|
||||||
|
JwtGroupsEnabled: br(false),
|
||||||
},
|
},
|
||||||
expectedArray: true,
|
expectedArray: true,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -105,6 +110,24 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: 15552000,
|
PeerLoginExpiration: 15552000,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
|
JwtGroupsClaimName: sr(""),
|
||||||
|
JwtGroupsEnabled: br(false),
|
||||||
|
},
|
||||||
|
expectedArray: false,
|
||||||
|
expectedID: accountID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PutAccount OK wiht JWT",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/accounts/" + accountID,
|
||||||
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedSettings: api.AccountSettings{
|
||||||
|
PeerLoginExpiration: 15552000,
|
||||||
|
PeerLoginExpirationEnabled: false,
|
||||||
|
JwtGroupsClaimName: sr("roles"),
|
||||||
|
JwtGroupsEnabled: br(true),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ components:
|
|||||||
description: Period of time after which peer login expires (seconds).
|
description: Period of time after which peer login expires (seconds).
|
||||||
type: integer
|
type: integer
|
||||||
example: 43200
|
example: 43200
|
||||||
|
jwt_groups_enabled:
|
||||||
|
description: Allows extract groups from JWT claim and add it to account groups.
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
jwt_groups_claim_name:
|
||||||
|
description: Name of the claim from which we extract groups names to add it to account groups.
|
||||||
|
type: string
|
||||||
|
example: "roles"
|
||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
@@ -462,6 +470,10 @@ components:
|
|||||||
description: Count of peers associated to the group
|
description: Count of peers associated to the group
|
||||||
type: integer
|
type: integer
|
||||||
example: 2
|
example: 2
|
||||||
|
issued:
|
||||||
|
description: How group was issued by API or from JWT token
|
||||||
|
type: string
|
||||||
|
example: api
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- name
|
- name
|
||||||
@@ -1262,6 +1274,33 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/users/{userId}/invite:
|
||||||
|
post:
|
||||||
|
summary: Resend user invitation
|
||||||
|
description: Resend user invitation
|
||||||
|
tags: [ Users ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: userId
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The unique identifier of a user
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Invite status code
|
||||||
|
content: {}
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/peers:
|
/api/peers:
|
||||||
get:
|
get:
|
||||||
summary: List all Peers
|
summary: List all Peers
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ type AccountRequest struct {
|
|||||||
|
|
||||||
// AccountSettings defines model for AccountSettings.
|
// AccountSettings defines model for AccountSettings.
|
||||||
type AccountSettings struct {
|
type AccountSettings struct {
|
||||||
|
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
|
||||||
|
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
|
||||||
|
|
||||||
|
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||||
|
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||||
|
|
||||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||||
|
|
||||||
@@ -174,6 +180,9 @@ type Group struct {
|
|||||||
// Id Group ID
|
// Id Group ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// Issued How group was issued by API or from JWT token
|
||||||
|
Issued *string `json:"issued,omitempty"`
|
||||||
|
|
||||||
// Name Group Name identifier
|
// Name Group Name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
@@ -189,6 +198,9 @@ type GroupMinimum struct {
|
|||||||
// Id Group ID
|
// Id Group ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// Issued How group was issued by API or from JWT token
|
||||||
|
Issued *string `json:"issued,omitempty"`
|
||||||
|
|
||||||
// Name Group Name identifier
|
// Name Group Name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok = account.Groups[groupID]
|
eg, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||||
return
|
return
|
||||||
@@ -110,6 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
ID: groupID,
|
ID: groupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
|
Issued: eg.Issued,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
||||||
@@ -152,6 +153,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
|
Issued: server.GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
||||||
@@ -166,7 +168,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteGroup handles group deletion request
|
// DeleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -190,8 +192,13 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteGroup(aID, groupID)
|
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_, ok := err.(*server.GroupLinkError)
|
||||||
|
if ok {
|
||||||
|
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -237,6 +244,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
|
|||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
PeersCount: len(group.Peers),
|
PeersCount: len(group.Peers),
|
||||||
|
Issued: &group.Issued,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
|
|||||||
@@ -11,17 +11,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var TestPeers = map[string]*server.Peer{
|
var TestPeers = map[string]*server.Peer{
|
||||||
@@ -42,9 +40,17 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
if groupID != "idofthegroup" {
|
if groupID != "idofthegroup" {
|
||||||
return nil, status.Errorf(status.NotFound, "not found")
|
return nil, status.Errorf(status.NotFound, "not found")
|
||||||
}
|
}
|
||||||
|
if groupID == "id-jwt-group" {
|
||||||
|
return &server.Group{
|
||||||
|
ID: "id-jwt-group",
|
||||||
|
Name: "Default Group",
|
||||||
|
Issued: server.GroupIssuedJWT,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
return &server.Group{
|
return &server.Group{
|
||||||
ID: "idofthegroup",
|
ID: "idofthegroup",
|
||||||
Name: "Group",
|
Name: "Group",
|
||||||
|
Issued: server.GroupIssuedAPI,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
||||||
@@ -80,11 +86,24 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
user.Id: user,
|
user.Id: user,
|
||||||
},
|
},
|
||||||
Groups: map[string]*server.Group{
|
Groups: map[string]*server.Group{
|
||||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT},
|
||||||
"id-all": {ID: "id-all", Name: "All"},
|
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI},
|
||||||
|
"id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI},
|
||||||
},
|
},
|
||||||
}, user, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
|
DeleteGroupFunc: func(accountID, userId, groupID string) error {
|
||||||
|
if groupID == "linked-grp" {
|
||||||
|
return &server.GroupLinkError{
|
||||||
|
Resource: "something",
|
||||||
|
Name: "linked-grp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupID == "invalid-grp" {
|
||||||
|
return fmt.Errorf("internal error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -169,6 +188,8 @@ func TestGetGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteGroup(t *testing.T) {
|
func TestWriteGroup(t *testing.T) {
|
||||||
|
groupIssuedAPI := "api"
|
||||||
|
groupIssuedJWT := "jwt"
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@@ -189,6 +210,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedGroup: &api.Group{
|
expectedGroup: &api.Group{
|
||||||
Id: "id-was-set",
|
Id: "id-was-set",
|
||||||
Name: "Default POSTed Group",
|
Name: "Default POSTed Group",
|
||||||
|
Issued: &groupIssuedAPI,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -210,6 +232,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedGroup: &api.Group{
|
expectedGroup: &api.Group{
|
||||||
Id: "id-existed",
|
Id: "id-existed",
|
||||||
Name: "Default POSTed Group",
|
Name: "Default POSTed Group",
|
||||||
|
Issued: &groupIssuedAPI,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -230,6 +253,19 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Write Group PUT not not change Issue",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/groups/id-jwt-group",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(`{"Name":"changed","Issued":"api"}`)),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedGroup: &api.Group{
|
||||||
|
Id: "id-jwt-group",
|
||||||
|
Name: "changed",
|
||||||
|
Issued: &groupIssuedJWT,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
@@ -271,3 +307,79 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteGroup(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Try to delete linked group",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/linked-grp",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Try to cause internal error",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/invalid-grp",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Try to cause internal error",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/invalid-grp",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Delete group",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/groups/any-grp",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
adminUser := server.NewAdminUser("test_user")
|
||||||
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||||
|
status, tc.expectedStatus, string(content))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedBody {
|
||||||
|
got := &util.ErrorResponse{}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, got.Code, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
|||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||||
|
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
case "bearer":
|
case "bearer":
|
||||||
err := m.CheckJWTFromRequest(w, r)
|
err := m.CheckJWTFromRequest(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Error when validating JWT claims: %s", err.Error())
|
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -208,6 +208,37 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(w, users)
|
util.WriteJSONObject(w, users)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUser resend invitations to users who haven't activated their accounts,
|
||||||
|
// prior to the expiration period.
|
||||||
|
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
targetUserID := vars["userId"]
|
||||||
|
if len(targetUserID) == 0 {
|
||||||
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(w, emptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||||
autoGroups := user.AutoGroups
|
autoGroups := user.AutoGroups
|
||||||
if autoGroups == nil {
|
if autoGroups == nil {
|
||||||
|
|||||||
@@ -98,6 +98,17 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return info, nil
|
return info, nil
|
||||||
},
|
},
|
||||||
|
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
|
if initiatorUserID != existingUserID {
|
||||||
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetUserID == notFoundUserID {
|
||||||
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -340,6 +351,51 @@ func TestCreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInviteUser(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestVars map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Invite User with Existing User",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/users/" + existingUserID + "/invite",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
requestVars: map[string]string{"userId": existingUserID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invite User with missing user_id",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/users/" + notFoundUserID + "/invite",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
requestVars: map[string]string{"userId": notFoundUserID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userHandler := initUsersTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
userHandler.InviteUser(rr, req)
|
||||||
|
|
||||||
|
res := rr.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := rr.Code; status != tc.expectedStatus {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDeleteUser(t *testing.T) {
|
func TestDeleteUser(t *testing.T) {
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -58,14 +63,9 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
|
|||||||
|
|
||||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||||
type errorResponse struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Code int `json:"code"`
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(httpStatus)
|
||||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||||
err := json.NewEncoder(w).Encode(&errorResponse{
|
err := json.NewEncoder(w).Encode(&ErrorResponse{
|
||||||
Message: errMsg,
|
Message: errMsg,
|
||||||
Code: httpStatus,
|
Code: httpStatus,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -98,6 +98,11 @@ type userExportJobStatusResponse struct {
|
|||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userVerificationJobRequest is a user verification request struct
|
||||||
|
type userVerificationJobRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
// auth0Profile represents an Auth0 user profile response
|
// auth0Profile represents an Auth0 user profile response
|
||||||
type auth0Profile struct {
|
type auth0Profile struct {
|
||||||
AccountID string `json:"wt_account_id"`
|
AccountID string `json:"wt_account_id"`
|
||||||
@@ -689,6 +694,48 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
|
|||||||
return &createResp, nil
|
return &createResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||||
|
userVerificationReq := userVerificationJobRequest{
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := am.helper.Marshal(userVerificationReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Couldn't get job response %v", err)
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while closing invite user response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||||
// If the status is "completed", then return the downloadLink
|
// If the status is "completed", then return the downloadLink
|
||||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||||
|
|||||||
@@ -440,6 +440,12 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
|||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -448,6 +448,12 @@ func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (am *AzureManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", extensionFields)
|
q.Add("$select", extensionFields)
|
||||||
|
|||||||
350
management/server/idp/google_workspace.go
Normal file
350
management/server/idp/google_workspace.go
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
|
"google.golang.org/api/googleapi"
|
||||||
|
"google.golang.org/api/option"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GoogleWorkspaceManager Google Workspace manager client instance.
|
||||||
|
type GoogleWorkspaceManager struct {
|
||||||
|
usersService *admin.UsersService
|
||||||
|
CustomerID string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
credentials ManagerCredentials
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
|
||||||
|
type GoogleWorkspaceClientConfig struct {
|
||||||
|
ServiceAccountKey string
|
||||||
|
CustomerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoogleWorkspaceCredentials Google Workspace authentication information.
|
||||||
|
type GoogleWorkspaceCredentials struct {
|
||||||
|
clientConfig GoogleWorkspaceClientConfig
|
||||||
|
helper ManagerHelper
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
return JWTToken{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||||
|
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
helper := JsonParser{}
|
||||||
|
|
||||||
|
if config.CustomerID == "" {
|
||||||
|
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := &GoogleWorkspaceCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new Admin SDK Directory service client
|
||||||
|
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := admin.NewService(context.Background(),
|
||||||
|
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope),
|
||||||
|
option.WithCredentials(adminCredentials),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GoogleWorkspaceManager{
|
||||||
|
usersService: service.Users,
|
||||||
|
CustomerID: config.CustomerID,
|
||||||
|
httpClient: httpClient,
|
||||||
|
credentials: credentials,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
|
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
|
metadata, err := gm.helper.Marshal(appMetadata)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &admin.User{
|
||||||
|
CustomSchemas: map[string]googleapi.RawMessage{
|
||||||
|
"app_metadata": metadata,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = gm.usersService.Update(userID, user).Do()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||||
|
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
|
user, err := gm.usersService.Get(userID).Projection("full").Do()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseGoogleWorkspaceUser(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given profile.
|
||||||
|
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
|
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID)
|
||||||
|
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
usersData := make([]*UserData, 0)
|
||||||
|
for _, user := range usersList.Users {
|
||||||
|
userData, err := parseGoogleWorkspaceUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
usersData = append(usersData, userData)
|
||||||
|
}
|
||||||
|
|
||||||
|
return usersData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// It returns a list of users indexed by accountID.
|
||||||
|
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
|
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
for _, user := range usersList.Users {
|
||||||
|
userData, err := parseGoogleWorkspaceUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := userData.AppMetadata.WTAccountID
|
||||||
|
if accountID != "" {
|
||||||
|
if _, ok := indexedUsers[accountID]; !ok {
|
||||||
|
indexedUsers[accountID] = make([]*UserData, 0)
|
||||||
|
}
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||||
|
func (gm *GoogleWorkspaceManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||||
|
invite := true
|
||||||
|
metadata := AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
}
|
||||||
|
|
||||||
|
username := &admin.UserName{}
|
||||||
|
fields := strings.Fields(name)
|
||||||
|
if n := len(fields); n > 0 {
|
||||||
|
username.GivenName = strings.Join(fields[:n-1], " ")
|
||||||
|
username.FamilyName = fields[n-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := gm.helper.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &admin.User{
|
||||||
|
Name: username,
|
||||||
|
PrimaryEmail: email,
|
||||||
|
CustomSchemas: map[string]googleapi.RawMessage{
|
||||||
|
"app_metadata": payload,
|
||||||
|
},
|
||||||
|
Password: GeneratePassword(8, 1, 1, 1),
|
||||||
|
}
|
||||||
|
user, err = gm.usersService.Insert(user).Do()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseGoogleWorkspaceUser(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
|
user, err := gm.usersService.Get(email).Projection("full").Do()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gm.appMetrics != nil {
|
||||||
|
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
userData, err := parseGoogleWorkspaceUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
users = append(users, userData)
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||||
|
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||||
|
// If that fails, it falls back to using the default Google credentials path.
|
||||||
|
// It returns the retrieved credentials or an error if unsuccessful.
|
||||||
|
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
|
||||||
|
log.Debug("retrieving google credentials from the base64 encoded service account key")
|
||||||
|
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
creds, err := google.CredentialsFromJSON(
|
||||||
|
context.Background(),
|
||||||
|
decodeKey,
|
||||||
|
admin.AdminDirectoryUserschemaScope,
|
||||||
|
admin.AdminDirectoryUserScope,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
return creds, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||||
|
log.Debug("falling back to default google credentials location")
|
||||||
|
|
||||||
|
creds, err = google.FindDefaultCredentials(
|
||||||
|
context.Background(),
|
||||||
|
admin.AdminDirectoryUserschemaScope,
|
||||||
|
admin.AdminDirectoryUserScope,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
|
||||||
|
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
|
||||||
|
schemaList, err := service.Schemas.List(customerID).Do()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// checks if app_metadata schema is already created
|
||||||
|
for _, schema := range schemaList.Schemas {
|
||||||
|
if schema.SchemaName == "app_metadata" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new app_metadata schema
|
||||||
|
appMetadataSchema := &admin.Schema{
|
||||||
|
SchemaName: "app_metadata",
|
||||||
|
Fields: []*admin.SchemaFieldSpec{
|
||||||
|
{
|
||||||
|
FieldName: "wt_account_id",
|
||||||
|
FieldType: "STRING",
|
||||||
|
MultiValued: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FieldName: "wt_pending_invite",
|
||||||
|
FieldType: "BOOL",
|
||||||
|
MultiValued: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseGoogleWorkspaceUser parse google user to UserData.
|
||||||
|
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) {
|
||||||
|
var appMetadata AppMetadata
|
||||||
|
|
||||||
|
// Get app metadata from custom schemas
|
||||||
|
if user.CustomSchemas != nil {
|
||||||
|
rawMessage := user.CustomSchemas["app_metadata"]
|
||||||
|
helper := JsonParser{}
|
||||||
|
|
||||||
|
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
ID: user.Id,
|
||||||
|
Email: user.PrimaryEmail,
|
||||||
|
Name: user.Name.FullName,
|
||||||
|
AppMetadata: appMetadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ type Manager interface {
|
|||||||
GetAllAccounts() (map[string][]*UserData, error)
|
GetAllAccounts() (map[string][]*UserData, error)
|
||||||
CreateUser(email string, name string, accountID string) (*UserData, error)
|
CreateUser(email string, name string, accountID string) (*UserData, error)
|
||||||
GetUserByEmail(email string) ([]*UserData, error)
|
GetUserByEmail(email string) ([]*UserData, error)
|
||||||
|
InviteUserByID(userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConfig defines common client configuration for all IdP manager
|
// ClientConfig defines common client configuration for all IdP manager
|
||||||
@@ -162,6 +163,12 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
APIToken: config.ExtraConfig["ApiToken"],
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
}
|
}
|
||||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
return NewOktaManager(oktaClientConfig, appMetrics)
|
||||||
|
case "google":
|
||||||
|
googleClientConfig := GoogleWorkspaceClientConfig{
|
||||||
|
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||||
|
CustomerID: config.ExtraConfig["CustomerId"],
|
||||||
|
}
|
||||||
|
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
|
|||||||
@@ -461,6 +461,12 @@ func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppM
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InviteUserByID resend invitations to users who haven't activated,
|
||||||
|
// their accounts prior to the expiration period.
|
||||||
|
func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||||
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||||
attrs := keycloakUserAttributes{}
|
attrs := keycloakUserAttributes{}
|
||||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user