mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-10 18:59:55 +00:00
Compare commits
124 Commits
test/engin
...
feature/va
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f15708d54 | ||
|
|
85ffbd1db5 | ||
|
|
abdba6c650 | ||
|
|
b025dbeb75 | ||
|
|
ffd3cd66f4 | ||
|
|
964e966269 | ||
|
|
81f69599ea | ||
|
|
356d3624c4 | ||
|
|
09c9f21a8b | ||
|
|
62899df75d | ||
|
|
eb68e35f44 | ||
|
|
4a9f755b13 | ||
|
|
f891959d64 | ||
|
|
6a8cb2b0f9 | ||
|
|
5df1973caf | ||
|
|
a6dc54c21a | ||
|
|
9b0424ea55 | ||
|
|
ea6d037b17 | ||
|
|
86ce2ed72b | ||
|
|
070e1dd890 | ||
|
|
8b61ffa78f | ||
|
|
9929b22afc | ||
|
|
5c0e4097d8 | ||
|
|
13aa9f7198 | ||
|
|
8a02d3eb9e | ||
|
|
53218f99bc | ||
|
|
e5ecf0e5b3 | ||
|
|
006524756c | ||
|
|
ced28c4376 | ||
|
|
88e4fc2245 | ||
|
|
c8d8748dcf | ||
|
|
507a40bd7f | ||
|
|
ccd4ae6315 | ||
|
|
96d2207684 | ||
|
|
f942491b91 | ||
|
|
8c8900be57 | ||
|
|
cee95461d1 | ||
|
|
49e65109d2 | ||
|
|
d93dd4fc7f | ||
|
|
3a88ac78ff | ||
|
|
da3a053e2b | ||
|
|
0e95f16cdd | ||
|
|
cd92646348 | ||
|
|
30a0d9c8c4 | ||
|
|
b2379175fe | ||
|
|
09bdd271f1 | ||
|
|
a42ebb8202 | ||
|
|
208a2b7169 | ||
|
|
15b83cb1e5 | ||
|
|
fdb1a1fe00 | ||
|
|
8284ae959c | ||
|
|
8cabb07728 | ||
|
|
57f7f43ecb | ||
|
|
2e20a586cb | ||
|
|
ed3c3c214e | ||
|
|
bdf114cd74 | ||
|
|
6d985c5991 | ||
|
|
ce7de03d6e | ||
|
|
9ee08fc441 | ||
|
|
271bed5f73 | ||
|
|
2a751645f9 | ||
|
|
d4edde90c2 | ||
|
|
5cc07ba42a | ||
|
|
70f1c394c1 | ||
|
|
c74a13e1a9 | ||
|
|
1ed44b810c | ||
|
|
41acacfba5 | ||
|
|
fc7157f82f | ||
|
|
63c510e80d | ||
|
|
716009b791 | ||
|
|
a915707d13 | ||
|
|
5108888163 | ||
|
|
4e2cf9c63a | ||
|
|
5dbdeff77a | ||
|
|
7523a9e7be | ||
|
|
75ab35563a | ||
|
|
c6650705a1 | ||
|
|
f29f8c009f | ||
|
|
8826196503 | ||
|
|
ca8565de1f | ||
|
|
ac06346f5c | ||
|
|
151969bdd7 | ||
|
|
441136e2c6 | ||
|
|
ff19b237d9 | ||
|
|
376ded1b2f | ||
|
|
73b9e1c926 | ||
|
|
fb627a308c | ||
|
|
c918bab09a | ||
|
|
7fa71419cd | ||
|
|
226dc95afa | ||
|
|
1548542df3 | ||
|
|
34114d4a55 | ||
|
|
f3ec200985 | ||
|
|
8ecbe675a1 | ||
|
|
f0d91bcfc4 | ||
|
|
eb9aadfd38 | ||
|
|
8bab9dc3c0 | ||
|
|
02c0a9b1da | ||
|
|
c76cd1d86e | ||
|
|
d990d95236 | ||
|
|
cf211f6337 | ||
|
|
8d9ea40bf1 | ||
|
|
7647701898 | ||
|
|
6554b26600 | ||
|
|
8455455142 | ||
|
|
c48f244bee | ||
|
|
b7fcd0d753 | ||
|
|
a19c2f660c | ||
|
|
936215b395 | ||
|
|
bb08adcbac | ||
|
|
f5ec234f09 | ||
|
|
26f089e30d | ||
|
|
713c0341be | ||
|
|
1bbd8ae4b0 | ||
|
|
a723c424f0 | ||
|
|
3e76deaa87 | ||
|
|
36d4c21671 | ||
|
|
181e8648a8 | ||
|
|
1012c2f990 | ||
|
|
1b28d1dfbc | ||
|
|
f17016b5e5 | ||
|
|
b6cef2ce2c | ||
|
|
dedf13d8f1 | ||
|
|
d676c41c74 |
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -42,4 +42,4 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.14"
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
@@ -223,4 +223,4 @@ jobs:
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
||||
@@ -96,6 +96,9 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
- netbird
|
||||
|
||||
@@ -23,6 +23,9 @@ builds:
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
- netbird-ui-darwin
|
||||
|
||||
@@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string {
|
||||
signalAddr := signalLis.Addr().String()
|
||||
config.Signal.URI = signalAddr
|
||||
|
||||
_, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
|
||||
_, mgmLis := startManagement(t, config, "../testdata/store.sql")
|
||||
mgmAddr := mgmLis.Addr().String()
|
||||
return mgmAddr
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,13 +22,19 @@ const (
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
entries map[string][][]string
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
@@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
@@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
}
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
@@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
@@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources [] netip.Prefix,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
|
||||
@@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := tableFilter
|
||||
if chain == chainRTNAT {
|
||||
table = tableNat
|
||||
}
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
@@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
func (r *router) createContainers() error {
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %v", chain, err)
|
||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %v", err)
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
return r.addJumpRules()
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
@@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
lointdir = "-i"
|
||||
}
|
||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
|
||||
@@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
sortPrefixes(sources)
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
@@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
return merged
|
||||
}
|
||||
|
||||
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func sortPrefixes(prefixes []netip.Prefix) {
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,6 +31,7 @@ const (
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameOutputFilter = "netbird-acl-output-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@@ -40,15 +43,14 @@ var (
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
chainFwFilter *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@@ -61,7 +63,7 @@ type iFaceMapper interface {
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
@@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
||||
}
|
||||
|
||||
m := &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
@@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
@@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward() {
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
|
||||
m.addPreroutingRule(preroutingChain)
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: preroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routeingFwChainName,
|
||||
Chain: m.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
@@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
|
||||
@@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
Exprs: getEstablishedExprs(1),
|
||||
}
|
||||
|
||||
conn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func getEstablishedExprs(register uint32) []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: register,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: register,
|
||||
DestRegister: register,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: register,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
@@ -24,7 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
@@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
err = r.removeAcceptForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
@@ -97,40 +99,7 @@ func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.cleanUpDefaultForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
return r.removeAcceptForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
@@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
@@ -157,25 +125,28 @@ func (r *router) createContainers() error {
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.acceptForwardRules()
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
@@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering(
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
||||
rule = r.conn.AddRule(rule)
|
||||
|
||||
return ruleKey, r.conn.Flush()
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
@@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
@@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
notDir := expr.MetaKeyOIFNAME
|
||||
if pair.Inverse {
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
notDir = expr.MetaKeyIIFNAME
|
||||
}
|
||||
|
||||
lo := ifname("lo")
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
@@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: notDir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: lo,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
@@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (r *router) acceptForwardRules() {
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// filter table exists but iptables is not
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
@@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() {
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
oifExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
@@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() {
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 2,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
@@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
@@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
@@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
@@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
// Verify actual nftables rule content
|
||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
@@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
||||
manager.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
// Write all fields to the hasher, with delimiters
|
||||
h.Write([]byte("sources:"))
|
||||
for _, src := range sources {
|
||||
h.Write([]byte(src.String()))
|
||||
h.Write([]byte(","))
|
||||
}
|
||||
|
||||
h.Write([]byte("destination:"))
|
||||
h.Write([]byte(destination.String()))
|
||||
|
||||
h.Write([]byte("proto:"))
|
||||
h.Write([]byte(proto))
|
||||
|
||||
h.Write([]byte("sPort:"))
|
||||
if sPort != nil {
|
||||
h.Write([]byte(sPort.String()))
|
||||
} else {
|
||||
h.Write([]byte("<nil>"))
|
||||
}
|
||||
|
||||
h.Write([]byte("dPort:"))
|
||||
if dPort != nil {
|
||||
h.Write([]byte(dPort.String()))
|
||||
} else {
|
||||
h.Write([]byte("<nil>"))
|
||||
}
|
||||
|
||||
h.Write([]byte("action:"))
|
||||
h.Write([]byte(strconv.Itoa(int(action))))
|
||||
hash := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// prepend destination prefix to be able to identify the rule
|
||||
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
|
||||
}
|
||||
|
||||
@@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -82,8 +82,6 @@ type Conn struct {
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
@@ -106,7 +104,8 @@ type Conn struct {
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
|
||||
endpointRelay *net.UDPAddr
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
|
||||
// for reconnection operations
|
||||
iCEDisconnected chan bool
|
||||
@@ -257,8 +256,7 @@ func (conn *Conn) Close() {
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
@@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
|
||||
defer conn.updateIceState(iceConnInfo)
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("set ICE to active connection")
|
||||
|
||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
||||
if err != nil {
|
||||
return
|
||||
var (
|
||||
ep *net.UDPAddr
|
||||
wgProxy wgproxy.Proxy
|
||||
err error
|
||||
)
|
||||
if iceConnInfo.RelayedOnLocal {
|
||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
ep = wgProxy.EndpointAddr()
|
||||
conn.wgProxyICE = wgProxy
|
||||
} else {
|
||||
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolveUDPaddr")
|
||||
conn.handleConfigurationFailure(err, nil)
|
||||
return
|
||||
}
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
if wgProxy != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
||||
}
|
||||
}
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
}
|
||||
|
||||
if wgProxy != nil {
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyICE = wgProxy
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
@@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// switch back to relay connection
|
||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
||||
if err != nil {
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
@@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("Relay connection is ready to use")
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.endpointRelay = endpointUdpAddr
|
||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
|
||||
if conn.currentConnPriority > connPriorityRelay {
|
||||
if conn.statusICE.Get() == StatusConnected {
|
||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
return
|
||||
}
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
conn.connIDRelay = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
}
|
||||
@@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("relay connection is disconnected")
|
||||
conn.log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == connPriorityRelay {
|
||||
log.Debugf("clean up WireGuard config")
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
conn.log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.endpointRelay = nil
|
||||
_ = conn.wgProxyRelay.CloseConn()
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
Relayed: conn.isRelayed(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
@@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
||||
if !iceConnInfo.RelayedOnLocal {
|
||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
||||
}
|
||||
conn.log.Debugf("setup ice turn connection")
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||
}
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) iceP2PIsActive() bool {
|
||||
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
}
|
||||
}
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Work()
|
||||
}
|
||||
return ep, wgProxy, nil
|
||||
}
|
||||
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
|
||||
@@ -5,7 +5,6 @@ package ebpf
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
||||
defer p.removeTurnConn(endpointPort)
|
||||
|
||||
var (
|
||||
err error
|
||||
n int
|
||||
)
|
||||
buf := make([]byte, 1500)
|
||||
for ctx.Err() == nil {
|
||||
n, err = remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
@@ -4,8 +4,13 @@ package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
ctxConn, cancel := context.WithCancel(ctx)
|
||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
e.remoteConn = remoteConn
|
||||
e.cancel = cancel
|
||||
return addr, err
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr
|
||||
Work()
|
||||
Pause()
|
||||
CloseConn() error
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
@@ -15,13 +15,17 @@ import (
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
// AddTurnConn
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUserSpaceProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUserSpaceProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
for ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
|
||||
@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
36
client/testdata/store.sql
vendored
Normal file
36
client/testdata/store.sql
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN TRANSACTION;
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
COMMIT;
|
||||
BIN
client/testdata/store.sqlite
vendored
BIN
client/testdata/store.sqlite
vendored
Binary file not shown.
@@ -1,12 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>netbird-ui</string>
|
||||
<key>CFBundleIconFile</key>
|
||||
<string>Netbird</string>
|
||||
<key>LSUIElement</key>
|
||||
<string>1</string>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -8,11 +8,11 @@ cask "{{ $projectName }}" do
|
||||
if Hardware::CPU.intel?
|
||||
url "{{ $amdURL }}"
|
||||
sha256 "{{ crypto.SHA256 $amdFileBytes }}"
|
||||
app "netbird_ui_darwin_amd64", target: "Netbird UI.app"
|
||||
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||
else
|
||||
url "{{ $armURL }}"
|
||||
sha256 "{{ crypto.SHA256 $armFileBytes }}"
|
||||
app "netbird_ui_darwin_arm64", target: "Netbird UI.app"
|
||||
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||
end
|
||||
|
||||
depends_on formula: "netbird"
|
||||
@@ -36,4 +36,4 @@ cask "{{ $projectName }}" do
|
||||
name "Netbird UI"
|
||||
desc "Netbird UI Client"
|
||||
homepage "https://www.netbird.io/"
|
||||
end
|
||||
end
|
||||
|
||||
23
go.mod
23
go.mod
@@ -19,8 +19,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.24.0
|
||||
golang.org/x/sys v0.21.0
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/sys v0.26.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@@ -38,6 +38,7 @@ require (
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
@@ -45,7 +46,7 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/google/nftables v0.2.0
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
@@ -55,12 +56,12 @@ require (
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/magiconair/properties v1.8.7
|
||||
github.com/mattn/go-sqlite3 v1.14.19
|
||||
github.com/mdlayher/socket v0.4.1
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -70,6 +71,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@@ -89,10 +91,10 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/term v0.21.0
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/term v0.25.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
@@ -133,7 +135,6 @@ require (
|
||||
github.com/containerd/containerd v1.7.16 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
@@ -210,6 +211,8 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -219,7 +222,7 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
|
||||
42
go.sum
42
go.sum
@@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
||||
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
|
||||
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||
@@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
@@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -774,8 +780,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -871,8 +877,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -901,8 +907,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -974,8 +980,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -983,8 +989,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
|
||||
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -999,8 +1005,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -873,7 +873,7 @@ services:
|
||||
zitadel:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
env_file:
|
||||
- ./zitadel.env
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
|
||||
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
|
||||
}
|
||||
|
||||
func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := mgmt.NewSqliteStore(ctx, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return store, func() {
|
||||
store.Close(ctx)
|
||||
os.Remove(filepath.Join(dataDir, "store.db"))
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
@@ -44,12 +45,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@@ -98,7 +102,6 @@ type AccountManager interface {
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||
@@ -178,6 +181,8 @@ type DefaultAccountManager struct {
|
||||
dnsDomain string
|
||||
peerLoginExpiry Scheduler
|
||||
|
||||
peerInactivityExpiry Scheduler
|
||||
|
||||
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||
userDeleteFromIDPEnabled bool
|
||||
|
||||
@@ -195,6 +200,13 @@ type Settings struct {
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
|
||||
PeerInactivityExpirationEnabled bool
|
||||
|
||||
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
|
||||
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
|
||||
PeerInactivityExpiration time.Duration
|
||||
|
||||
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
||||
RegularUsersViewBlocked bool
|
||||
|
||||
@@ -225,6 +237,9 @@ func (s *Settings) Copy() *Settings {
|
||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||
JWTAllowGroups: s.JWTAllowGroups,
|
||||
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
@@ -606,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetInactivePeers returns peers that have been expired by inactivity
|
||||
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, inactivePeer := range a.GetPeersWithInactivity() {
|
||||
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||
if inactive {
|
||||
peers = append(peers, inactivePeer)
|
||||
}
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
|
||||
peersWithExpiry := a.GetPeersWithInactivity()
|
||||
if len(peersWithExpiry) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
|
||||
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
for _, peer := range a.Peers {
|
||||
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetPeers returns a list of all Account peers
|
||||
func (a *Account) GetPeers() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
@@ -972,6 +1041,7 @@ func BuildManager(
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
peerInactivityExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
@@ -1100,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@@ -1110,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@@ -1145,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
|
||||
}
|
||||
}
|
||||
|
||||
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
|
||||
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
|
||||
expiredPeers := account.GetInactivePeers()
|
||||
var peerIDs []string
|
||||
for _, peer := range expiredPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
}
|
||||
|
||||
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
||||
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
||||
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
||||
}
|
||||
}
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
// If ID is already in use (due to collision) we try one more time before returning error
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
||||
@@ -1285,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
@@ -1300,28 +1432,39 @@ func isNil(i idp.Manager) bool {
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &Account{
|
||||
Id: accountID,
|
||||
Users: make(map[string]*User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
cachedAccount.Users[user.Id] = user
|
||||
}
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
account.Id, userID)
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1545,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
|
||||
if claims.Domain != "" {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
userObj := account.Users[claims.UserId]
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||
account.Domain = lowerDomain
|
||||
}
|
||||
// prevent updating category for different domain until admin logs in
|
||||
if account.Domain == lowerDomain {
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
}
|
||||
} else {
|
||||
if claims.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(ctx, account)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
newDomain = lowerDomain
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
existingAcc *Account,
|
||||
primaryDomain bool,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) error {
|
||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1594,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
// if domain already has a primary account, add regular user
|
||||
if domainAcc != nil {
|
||||
account = domainAcc
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
return account, nil
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
usersMap := make(map[string]*User)
|
||||
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
@@ -1775,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
return "", "", fmt.Errorf("user ID is empty")
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@@ -1953,24 +2131,29 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for account: %s", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is not private or domain is invalid, it will return the account ID by user ID.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
// Use cases:
|
||||
//
|
||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
||||
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
|
||||
//
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
|
||||
//
|
||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
||||
// New user + New account + Existing Public Domain -> create account, user role = owner
|
||||
//
|
||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||
//
|
||||
@@ -1980,98 +2163,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
|
||||
if claims.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
return "", errors.New(emptyUserID)
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
if claims.AccountId != "" {
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
||||
}
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
||||
defer unlock()
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
if claims.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||
defer unlockAccount()
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != "" {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userAccountID, nil
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return domainAccountID, nil, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
|
||||
// check again if the domain has a primary account because of simultaneous requests
|
||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
cancel()
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
func handleNotFound(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||
return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
@@ -2337,6 +2546,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -139,6 +139,13 @@ const (
|
||||
PostureCheckUpdated Activity = 61
|
||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||
PostureCheckDeleted Activity = 62
|
||||
|
||||
PeerInactivityExpirationEnabled Activity = 63
|
||||
PeerInactivityExpirationDisabled Activity = 64
|
||||
|
||||
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
|
||||
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
||||
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
||||
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
||||
|
||||
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
|
||||
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
|
||||
|
||||
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
82
management/server/differs/netip.go
Normal file
82
management/server/differs/netip.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
@@ -125,26 +125,31 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for account dns settings: %s", accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -210,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with peers from DNS settings should trigger updates to account peers and send peer updates
|
||||
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
account.Settings = &Settings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -121,12 +121,21 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for group: %v", newGroupIDs)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
@@ -238,8 +247,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -282,8 +289,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
@@ -336,7 +341,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for group: %s", groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -366,7 +375,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for group: %s", groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -469,3 +482,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -44,3 +44,8 @@ func (g *Group) Copy() *Group {
|
||||
copy(group.Peers, g.Peers)
|
||||
return group
|
||||
}
|
||||
|
||||
// HasPeers checks if the group has any peers.
|
||||
func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
@@ -4,13 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
}
|
||||
return am, acc, nil
|
||||
}
|
||||
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
t.Run("saving unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding a peer to a group that is not linked to any resource should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing a peer from a group that is not linked to any resource should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting group should not update account peers and not send peer update
|
||||
t.Run("deleting group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving a group linked to policy should update account peers and send peer update
|
||||
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// removing peer from a linked group should update account peers and send peer update
|
||||
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to name server group should update account peers and send peer update
|
||||
t.Run("saving group linked to name server group", func(t *testing.T) {
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupC"},
|
||||
true, nil, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to route should update account peers and send peer update
|
||||
t.Run("saving group linked to route", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
ID: "route",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupA"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to dns settings should update account peers and send peer update
|
||||
t.Run("saving group linked to dns settings", func(t *testing.T) {
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupD"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
|
||||
@@ -54,6 +54,14 @@ components:
|
||||
description: Period of time after which peer login expires (seconds).
|
||||
type: integer
|
||||
example: 43200
|
||||
peer_inactivity_expiration_enabled:
|
||||
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||
type: boolean
|
||||
example: true
|
||||
peer_inactivity_expiration:
|
||||
description: Period of time of inactivity after which peer session expires (seconds).
|
||||
type: integer
|
||||
example: 43200
|
||||
regular_users_view_blocked:
|
||||
description: Allows blocking regular users from viewing parts of the system.
|
||||
type: boolean
|
||||
@@ -81,6 +89,8 @@ components:
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
- peer_inactivity_expiration_enabled
|
||||
- peer_inactivity_expiration
|
||||
- regular_users_view_blocked
|
||||
AccountExtraSettings:
|
||||
type: object
|
||||
@@ -243,6 +253,9 @@ components:
|
||||
login_expiration_enabled:
|
||||
type: boolean
|
||||
example: false
|
||||
inactivity_expiration_enabled:
|
||||
type: boolean
|
||||
example: false
|
||||
approval_required:
|
||||
description: (Cloud only) Indicates whether peer needs approval
|
||||
type: boolean
|
||||
@@ -251,6 +264,7 @@ components:
|
||||
- name
|
||||
- ssh_enabled
|
||||
- login_expiration_enabled
|
||||
- inactivity_expiration_enabled
|
||||
Peer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerMinimum'
|
||||
@@ -327,6 +341,10 @@ components:
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2023-05-05T09:00:35.477782Z"
|
||||
inactivity_expiration_enabled:
|
||||
description: Indicates whether peer inactivity expiration has been enabled or not
|
||||
type: boolean
|
||||
example: false
|
||||
approval_required:
|
||||
description: (Cloud only) Indicates whether peer needs approval
|
||||
type: boolean
|
||||
@@ -354,6 +372,7 @@ components:
|
||||
- last_seen
|
||||
- login_expiration_enabled
|
||||
- login_expired
|
||||
- inactivity_expiration_enabled
|
||||
- os
|
||||
- ssh_enabled
|
||||
- user_id
|
||||
|
||||
@@ -220,6 +220,12 @@ type AccountSettings struct {
|
||||
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||
|
||||
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
||||
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
||||
|
||||
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
|
||||
|
||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||
|
||||
@@ -538,6 +544,9 @@ type Peer struct {
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
@@ -613,6 +622,9 @@ type PeerBatch struct {
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
@@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
|
||||
// PeerRequest defines model for PeerRequest.
|
||||
type PeerRequest struct {
|
||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
Name string `json:"name"`
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
Name string `json:"name"`
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
}
|
||||
|
||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@@ -14,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PeersHandler is a handler that returns peers of the account
|
||||
@@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
|
||||
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||
}
|
||||
|
||||
if req.ApprovalRequired != nil {
|
||||
@@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
}
|
||||
|
||||
return &api.Peer{
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.LastLogin,
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
ApprovalRequired: !approved,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.LastLogin,
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
ApprovalRequired: !approved,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
|
||||
|
||||
func Test_SyncProtocol(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
|
||||
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile)
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
|
||||
}
|
||||
|
||||
func Test_SyncStatusRace(t *testing.T) {
|
||||
t.Skip()
|
||||
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
|
||||
t.Skip("Skipping on CI and Postgres store")
|
||||
}
|
||||
@@ -482,9 +483,10 @@ func Test_SyncStatusRace(t *testing.T) {
|
||||
}
|
||||
func testSyncStatusRace(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Skip()
|
||||
dir := t.TempDir()
|
||||
|
||||
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
|
||||
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPerformance(t *testing.T) {
|
||||
t.Skip()
|
||||
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on CI or Windows")
|
||||
}
|
||||
@@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
|
||||
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
|
||||
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
|
||||
@@ -58,7 +58,7 @@ var _ = Describe("Management service", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
config.Datadir = dataDir
|
||||
|
||||
s, listener = startServer(config, dataDir, "testdata/store.sqlite")
|
||||
s, listener = startServer(config, dataDir, "testdata/store.sql")
|
||||
addr = listener.Addr().String()
|
||||
client, conn = createRawClient(addr)
|
||||
|
||||
@@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s := grpc.NewServer()
|
||||
|
||||
store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir)
|
||||
store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
@@ -57,7 +57,6 @@ type MockAccountManager struct {
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if am.UpdatePeerSSHKeyFunc != nil {
|
||||
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
if am.UpdatePeerFunc != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -66,13 +67,15 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for ns group: %s", newNSGroup.ID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@@ -80,7 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -98,16 +100,19 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for ns group: %s", nsGroupToSave.ID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
@@ -131,13 +136,15 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", nsGroupID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
return nil
|
||||
@@ -277,3 +284,11 @@ func validateDomain(domain string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -773,7 +775,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
var newNameServerGroupA *nbdns.NameServerGroup
|
||||
var newNameServerGroupB *nbdns.NameServerGroup
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupA, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving a nameserver group with a distribution group with no peers should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should update account peers and send peer update
|
||||
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving a nameserver group with a distribution group with peers should update account peers and send peer update
|
||||
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving unchanged nameserver group should update account peers and not send peer update
|
||||
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting a nameserver group should update account peers and send peer update
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -41,9 +41,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64
|
||||
Serial uint64 `diff:"-"`
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -110,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
@@ -138,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
if oldStatus.LoginExpired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
return oldStatus.LoginExpired, nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
@@ -185,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
peerLabelUpdated := peer.Name != update.Name
|
||||
|
||||
if peerLabelUpdated {
|
||||
peer.Name = update.Name
|
||||
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
@@ -219,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||
}
|
||||
|
||||
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||
|
||||
event := activity.PeerInactivityExpirationEnabled
|
||||
if !update.InactivityExpirationEnabled {
|
||||
event = activity.PeerInactivityExpirationDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@@ -226,7 +263,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if peerLabelUpdated {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", update.ID)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
@@ -270,6 +311,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
@@ -288,6 +330,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -298,7 +342,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", peerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -388,9 +436,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var groupsToAdd []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var groupsToAdd []string
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
@@ -442,23 +490,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser,
|
||||
}
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
@@ -541,7 +590,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", newPeer.ID)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
@@ -862,51 +915,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if peer.SSHKey == sshKey {
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.SSHKey = sshKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@@ -999,7 +1007,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
@@ -1013,3 +1021,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||
}
|
||||
|
||||
@@ -17,35 +17,37 @@ type Peer struct {
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
SetupKey string `diff:"-"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
UserID string `diff:"-"`
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time
|
||||
LastLogin time.Time `diff:"-"`
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool
|
||||
Ephemeral bool `diff:"-"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
@@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer {
|
||||
CreatedAt: p.CreatedAt,
|
||||
Ephemeral: p.Ephemeral,
|
||||
Location: p.Location,
|
||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
|
||||
p.Status = newStatus
|
||||
}
|
||||
|
||||
// SessionExpired indicates whether the peer's session has expired or not.
|
||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
|
||||
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
|
||||
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||
// Only peers added by interactive SSO login can be expired.
|
||||
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
|
||||
return false, 0
|
||||
}
|
||||
expiresAt := p.Status.LastSeen.Add(expiresIn)
|
||||
now := time.Now()
|
||||
timeLeft := expiresAt.Sub(now)
|
||||
return timeLeft <= 0, timeLeft
|
||||
}
|
||||
|
||||
// LoginExpired indicates whether the peer's login has expired or not.
|
||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||
|
||||
@@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_SessionExpired(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expirationEnabled bool
|
||||
lastLogin time.Time
|
||||
connected bool
|
||||
expected bool
|
||||
accountSettings *Settings
|
||||
}{
|
||||
{
|
||||
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
|
||||
expirationEnabled: false,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Peer Inactivity Should Expire",
|
||||
expirationEnabled: true,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Peer Inactivity Should Not Expire",
|
||||
expirationEnabled: true,
|
||||
connected: true,
|
||||
lastLogin: time.Now().UTC(),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tt {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
peerStatus := &nbpeer.PeerStatus{
|
||||
Connected: c.connected,
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
InactivityExpirationEnabled: c.expirationEnabled,
|
||||
LastLogin: c.lastLogin,
|
||||
Status: peerStatus,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
|
||||
assert.Equal(t, expired, c.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -1004,7 +1066,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1069,7 +1131,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1135,7 +1197,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1188,6 +1250,325 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
|
||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||
}
|
||||
|
||||
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a user with auto groups
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
|
||||
{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupA"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupB"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser3",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupC"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
var peer4 *nbpeer.Peer
|
||||
var peer5 *nbpeer.Peer
|
||||
var peer6 *nbpeer.Peer
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to unlinked group should not update account peers and not send peer update
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with unlinked group should not update account peers and not send peer update
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating peer label should update account peers and send peer update
|
||||
t.Run("updating peer label", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
peer1.Name = "peer-1"
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to policy should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with route should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with route", func(t *testing.T) {
|
||||
route := nbroute.Route{
|
||||
ID: "testingRoute1",
|
||||
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: nbroute.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupB"},
|
||||
}
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to route should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with name server group should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupC"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to name server group should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
|
||||
}
|
||||
}
|
||||
|
||||
// ruleGroups returns a list of all groups referenced in the policy's rules,
|
||||
// including sources and destinations.
|
||||
func (p *Policy) ruleGroups() []string {
|
||||
groups := make([]string, 0)
|
||||
for _, rule := range p.Rules {
|
||||
groups = append(groups, rule.Sources...)
|
||||
groups = append(groups, rule.Destinations...)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// FirewallRule is a rule of the firewall.
|
||||
type FirewallRule struct {
|
||||
// PeerIP of the peer
|
||||
@@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -363,7 +376,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for policy: %s", policy.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -428,7 +445,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
||||
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
@@ -442,18 +459,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
}
|
||||
|
||||
oldPolicy := account.Policies[policyIdx]
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
return nil
|
||||
|
||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||
|
||||
return updateAccountPeers, nil
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return nil
|
||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
@@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
|
||||
return 0 // a is equal to b
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with source group containing peers, but destination group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: false,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||
// or send peer update
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Description: "updated description",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Deleting policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -67,8 +68,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
if exists {
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for posture checks: %s", postureChecks.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -148,13 +152,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||
}
|
||||
|
||||
// check policy links
|
||||
for _, policy := range account.Policies {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if id == postureChecksID {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||
}
|
||||
}
|
||||
// Check if posture check is linked to any policy
|
||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||
}
|
||||
|
||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||
@@ -217,3 +217,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||
for _, policy := range account.Policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||
if !isLinked {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -118,3 +121,413 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
ID: "postureCheck",
|
||||
Name: "postureCheck",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.28.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Saving unused posture check should not update account peers and not send peer update
|
||||
t.Run("saving unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unused posture check should not update account peers and not send peer update
|
||||
t.Run("updating unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
policy := Policy{
|
||||
ID: "policyA",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
// Linking posture check to policy should trigger update account peers and send peer update
|
||||
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture checks should update account peers and send peer update
|
||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing posture check from policy should trigger account peers update and send peer update
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.SourcePostureChecks = []string{}
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting unused posture check should not trigger account peers update and not send peer update
|
||||
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where destination has peers but source does not
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policyA",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*group.Group{
|
||||
"groupA": {
|
||||
ID: "groupA",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
"groupB": {
|
||||
ID: "groupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "checkA",
|
||||
},
|
||||
{
|
||||
ID: "checkB",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
account.Groups["groupA"].Peers = []string{}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -237,7 +237,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", newRoute.ID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
@@ -313,6 +317,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
@@ -320,7 +325,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", routeToSave.ID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
@@ -350,7 +359,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", routy.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -641,3 +654,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -1257,7 +1258,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||
t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) {
|
||||
route := route.Route{
|
||||
ID: "testingRoute1",
|
||||
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupA"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupA"},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
|
||||
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
|
||||
route := route.Route{
|
||||
ID: "testingRoute2",
|
||||
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{routeGroup3},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup3},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
baseRoute := route.Route{
|
||||
ID: "testingRoute3",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
|
||||
// Creating route should update account peers and send peer update
|
||||
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newRoute, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
baseRoute = *newRoute
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating the route should update account peers and send peer update when there is peers in group
|
||||
t.Run("updating route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unchanged route should update account peers and not send peer update
|
||||
t.Run("updating unchanged route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting the route should update account peers and send peer update
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to route peer groups that do not have any peers should update account peers and send peer update
|
||||
t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.12.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to route groups that do not have any peers should update account peers and send peer update
|
||||
t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.13.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
}
|
||||
}()
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return newKey, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
@@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) {
|
||||
key.UpdatedAt, key.AutoGroups)
|
||||
|
||||
}
|
||||
|
||||
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"group"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var setupKey *SetupKey
|
||||
|
||||
// Creating setup key should not update account peers and not send peer update
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving setup key should not update account peers and not send peer update
|
||||
t.Run("saving setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account %s", accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
@@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
@@ -911,28 +948,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
|
||||
func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewPostgresqlStore(ctx, dsn, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(ctx, fileStore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
err := store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
|
||||
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewPostgresqlStore(ctx, dsn, metrics)
|
||||
@@ -1139,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
// we may need to reconsider changing the types.
|
||||
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||
if s.storeEngine == PostgresStoreEngine {
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
} else {
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
|
||||
@@ -11,14 +11,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||
t.Skip("skip large test on non-linux OS due to environment restrictions")
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Run("SQLite", func(t *testing.T) {
|
||||
store := newSqliteStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
|
||||
// create store outside to have a better time counter for the test
|
||||
store := newPostgresqlStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
@@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
@@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
@@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
@@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 1 {
|
||||
@@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
@@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
@@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
@@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
@@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
@@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
|
||||
@@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
@@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
@@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
// TODO: figure out why this fails on postgres
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
|
||||
err := migrate(context.Background(), store.db)
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on empty db")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
@@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err = store.db.Save(act).Error
|
||||
err = store.(*SqlStore).db.Save(act).Error
|
||||
require.NoError(t, err, "Failed to insert Gob data")
|
||||
|
||||
type route struct {
|
||||
@@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) {
|
||||
Route: route2.Route{ID: "route1"},
|
||||
}
|
||||
|
||||
err = store.db.Save(rt).Error
|
||||
err = store.(*SqlStore).db.Save(rt).Error
|
||||
require.NoError(t, err, "Failed to insert Gob data")
|
||||
|
||||
err = migrate(context.Background(), store.db)
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on gob populated db")
|
||||
|
||||
err = migrate(context.Background(), store.db)
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
|
||||
err = store.db.Delete(rt).Where("id = ?", "route1").Error
|
||||
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
||||
require.NoError(t, err, "Failed to delete Gob data")
|
||||
|
||||
prefix = netip.MustParsePrefix("12.0.0.0/24")
|
||||
@@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) {
|
||||
Peer: "peer-id",
|
||||
}
|
||||
|
||||
err = store.db.Save(nRT).Error
|
||||
err = store.(*SqlStore).db.Save(nRT).Error
|
||||
require.NoError(t, err, "Failed to insert json nil slice data")
|
||||
|
||||
err = migrate(context.Background(), store.db)
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
||||
|
||||
err = migrate(context.Background(), store.db)
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
|
||||
}
|
||||
@@ -716,63 +734,15 @@ func newAccount(store Store, id int) error {
|
||||
return store.SaveAccount(context.Background(), account)
|
||||
}
|
||||
|
||||
func newPostgresqlStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
cleanUp, err := testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("could not initialize postgresql store: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename)
|
||||
t.Cleanup(cleanUpQ)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cleanUpP, err := testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUpP)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return pstore
|
||||
}
|
||||
|
||||
func TestPostgresql_NewStore(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
@@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
@@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
@@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 1 {
|
||||
@@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
@@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
@@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
@@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
@@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
@@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
@@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1187,11 +1161,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
group := &nbgroup.Group{
|
||||
ID: "group-id",
|
||||
AccountID: "account-id",
|
||||
@@ -1215,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
t.Run("Should update attributes with public domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "example.com"
|
||||
category := "public"
|
||||
IsDomainPrimaryAccount := false
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should update attributes with private domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should fail when account does not exist", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "All", group.Name)
|
||||
}
|
||||
|
||||
@@ -9,10 +9,12 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
@@ -56,9 +58,11 @@ type Store interface {
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
@@ -240,30 +244,41 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreFromSqlite is only used in tests
|
||||
func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
|
||||
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
|
||||
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
||||
// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database
|
||||
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
|
||||
kind := getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
kind = SqliteStoreEngine
|
||||
}
|
||||
|
||||
var store *SqlStore
|
||||
var err error
|
||||
var cleanUp func()
|
||||
|
||||
if filename == "" {
|
||||
store, err = NewSqliteStore(ctx, dataDir, nil)
|
||||
cleanUp = func() {
|
||||
store.Close(ctx)
|
||||
}
|
||||
} else {
|
||||
store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename)
|
||||
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = storeSqliteFileName
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if filename != "" {
|
||||
err = loadSQL(db, filename)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to load SQL file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
|
||||
}
|
||||
cleanUp := func() {
|
||||
store.Close(ctx)
|
||||
}
|
||||
|
||||
if kind == PostgresStoreEngine {
|
||||
cleanUp, err = testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
@@ -284,21 +299,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string
|
||||
return store, cleanUp, nil
|
||||
}
|
||||
|
||||
func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) {
|
||||
err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
|
||||
func loadSQL(db *gorm.DB, filepath string) error {
|
||||
sqlContent, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := NewSqliteStore(ctx, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
queries := strings.Split(string(sqlContent), ";")
|
||||
|
||||
for _, query := range queries {
|
||||
query = strings.TrimSpace(query)
|
||||
if query != "" {
|
||||
err := db.Exec(query).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return store, func() {
|
||||
store.Close(ctx)
|
||||
os.Remove(filepath.Join(dataDir, "store.db"))
|
||||
}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.
|
||||
|
||||
37
management/server/testdata/extended-store.sql
vendored
Normal file
37
management/server/testdata/extended-store.sql
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0);
|
||||
INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
|
||||
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
BIN
management/server/testdata/extended-store.sqlite
vendored
BIN
management/server/testdata/extended-store.sqlite
vendored
Binary file not shown.
36
management/server/testdata/store.sql
vendored
Normal file
36
management/server/testdata/store.sql
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
BIN
management/server/testdata/store.sqlite
vendored
BIN
management/server/testdata/store.sqlite
vendored
Binary file not shown.
35
management/server/testdata/store_policy_migrate.sql
vendored
Normal file
35
management/server/testdata/store_policy_migrate.sql
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
Binary file not shown.
35
management/server/testdata/store_with_expired_peers.sql
vendored
Normal file
35
management/server/testdata/store_with_expired_peers.sql
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
Binary file not shown.
39
management/server/testdata/storev1.sql
vendored
Normal file
39
management/server/testdata/storev1.sql
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
|
||||
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
|
||||
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
|
||||
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
|
||||
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
|
||||
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0);
|
||||
INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0);
|
||||
INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0);
|
||||
INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0);
|
||||
INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
BIN
management/server/testdata/storev1.sqlite
vendored
BIN
management/server/testdata/storev1.sqlite
vendored
Binary file not shown.
@@ -2,9 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/differs"
|
||||
"github.com/r3labs/diff/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -14,14 +18,17 @@ import (
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.Mutex
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
@@ -29,9 +36,10 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.Mutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
@@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
@@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
@@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
@@ -170,3 +200,77 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return false
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||
}
|
||||
isNew, err = true, nil
|
||||
}()
|
||||
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
log.Tracef("new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
lastSentUpdate.Update.NetworkMap.GetSerial(),
|
||||
currUpdateToSend.Update.NetworkMap.GetSerial(),
|
||||
)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
differ, err := diff.NewDiffer(
|
||||
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||
}
|
||||
|
||||
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||
|
||||
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
return len(changelog) > 0, nil
|
||||
}
|
||||
|
||||
// getChecksFiles returns a list of files from the given checks.
|
||||
func getChecksFiles(checks []*proto.Checks) []string {
|
||||
files := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
files = append(files, check.GetFiles()...)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
@@ -2,10 +2,19 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@@ -77,3 +86,465 @@ func TestCloseChannel(t *testing.T) {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(nil)
|
||||
ctx := context.Background()
|
||||
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||
t.Run("Unchanged value", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating routes network", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Updating routes groups", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating network map peers", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.4"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer4-key",
|
||||
DNSLabel: "peer4",
|
||||
SSHKey: "peer4-ssh-key",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating process check", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
|
||||
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||
check := &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/netbird",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||
check = &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/local/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newDomain := "newexample.com"
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||
newDomain,
|
||||
)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating peer IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newRule := &FirewallRule{
|
||||
PeerIP: "192.168.1.3",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: string(PolicyTrafficActionDrop),
|
||||
Protocol: string(PolicyRuleProtocolUDP),
|
||||
Port: "53",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Removing nameserver", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating name server IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||
t.Helper()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
|
||||
secretManager := NewTimeBasedAuthSecretsManager(
|
||||
NewPeersUpdateManager(nil),
|
||||
&TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{
|
||||
Duration: defaultDuration,
|
||||
},
|
||||
Secret: "secret",
|
||||
Turns: []*Host{TurnTestHost},
|
||||
},
|
||||
&Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "secret",
|
||||
},
|
||||
)
|
||||
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
{
|
||||
ID: "route2",
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route2",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
turnToken, err := secretManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
relayToken, err := secretManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||
NetworkMap: networkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,10 +20,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleBillingAdmin UserRole = "billing_admin"
|
||||
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
@@ -41,6 +43,8 @@ func StrRoleToUserRole(strRole string) UserRole {
|
||||
return UserRoleAdmin
|
||||
case "user":
|
||||
return UserRoleUser
|
||||
case "billing_admin":
|
||||
return UserRoleBillingAdmin
|
||||
default:
|
||||
return UserRoleUnknown
|
||||
}
|
||||
@@ -470,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -482,15 +486,24 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for user: %s", targetUserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
|
||||
peers, err := account.FindUserPeers(targetUserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to find user peers")
|
||||
return false, status.Errorf(status.Internal, "failed to find user peers")
|
||||
}
|
||||
|
||||
hadPeers := len(peers) > 0
|
||||
if !hadPeers {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
@@ -498,7 +511,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
}
|
||||
|
||||
return am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
||||
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
@@ -742,6 +755,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||
var (
|
||||
expiredPeers []*nbpeer.Peer
|
||||
userIDs []string
|
||||
eventsToStore []func()
|
||||
)
|
||||
|
||||
@@ -750,6 +764,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
userIDs = append(userIDs, update.Id)
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
if !addIfNotExists {
|
||||
@@ -813,8 +829,10 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for user: %v", userIDs)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
@@ -1164,7 +1182,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var (
|
||||
allErrors error
|
||||
updateAccountPeers bool
|
||||
)
|
||||
|
||||
deletedUsersMeta := make(map[string]map[string]any)
|
||||
for _, targetUserID := range targetUserIDs {
|
||||
@@ -1190,12 +1211,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
continue
|
||||
}
|
||||
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if hadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
deletedUsersMeta[targetUserID] = meta
|
||||
}
|
||||
@@ -1205,7 +1230,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return fmt.Errorf("failed to delete users: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("Skipping account peers update for user: %v", targetUserIDs)
|
||||
}
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
@@ -1214,11 +1243,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return allErrors
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
@@ -1229,16 +1258,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
@@ -1251,7 +1280,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
@@ -1330,3 +1359,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
|
||||
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
|
||||
for _, peer := range account.Peers {
|
||||
if slices.Contains(userIDs, peer.UserID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -10,9 +10,12 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
// account groups propagation is enabled
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// updating user with no linked peers should not update account peers and not send peer update
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// deleting user with no linked peers should not update account peers and not send peer update
|
||||
t.Run("deleting user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1")
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// create a user and add new peer with the user
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// updating user with linked peers should update account peers and send peer update
|
||||
t.Run("updating user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
})
|
||||
|
||||
// deleting user with linked peers should update account peers and send peer update
|
||||
t.Run("deleting user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2")
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,8 +16,10 @@ const (
|
||||
type Metrics struct {
|
||||
metric.Meter
|
||||
|
||||
TransferBytesSent metric.Int64Counter
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
TransferBytesSent metric.Int64Counter
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
AuthenticationTime metric.Float64Histogram
|
||||
PeerStoreTime metric.Float64Histogram
|
||||
|
||||
peers metric.Int64UpDownCounter
|
||||
peerActivityChan chan string
|
||||
@@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Metrics{
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
TransferBytesRecv: bytesRecv,
|
||||
peers: peers,
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
TransferBytesRecv: bytesRecv,
|
||||
AuthenticationTime: authTime,
|
||||
PeerStoreTime: peerStoreTime,
|
||||
peers: peers,
|
||||
|
||||
ctx: ctx,
|
||||
peerActivityChan: make(chan string, 10),
|
||||
@@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
|
||||
m.peerLastActive[id] = time.Time{}
|
||||
}
|
||||
|
||||
// RecordAuthenticationTime measures the time taken for peer authentication
|
||||
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
|
||||
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// RecordPeerStoreTime measures the time to store the peer in map
|
||||
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||
func (m *Metrics) PeerDisconnected(id string) {
|
||||
m.peers.Add(m.ctx, -1)
|
||||
@@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getStandardBucketBoundaries() []float64 {
|
||||
return []float64{
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
50,
|
||||
100,
|
||||
500,
|
||||
1000,
|
||||
5000,
|
||||
10000,
|
||||
}
|
||||
}
|
||||
|
||||
153
relay/server/handshake.go
Normal file
153
relay/server/handshake.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
// preparedMsg contains the marshalled success response messages
|
||||
type preparedMsg struct {
|
||||
responseHelloMsg []byte
|
||||
responseAuthMsg []byte
|
||||
}
|
||||
|
||||
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
|
||||
rhm, err := marshalResponseHelloMsg(instanceURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ram, err := messages.MarshalAuthResponse(instanceURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
|
||||
}
|
||||
|
||||
return &preparedMsg{
|
||||
responseHelloMsg: rhm,
|
||||
responseAuthMsg: ram,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||
addr := &address.Address{URL: instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal response address: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
|
||||
}
|
||||
return responseMsg, nil
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn net.Conn
|
||||
validator auth.Validator
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
handshakeMethodAuth bool
|
||||
peerID string
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := h.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
var (
|
||||
bytePeerID []byte
|
||||
peerID string
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
case messages.MsgTypeAuth:
|
||||
h.handshakeMethodAuth = true
|
||||
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.peerID = peerID
|
||||
return bytePeerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeResponse() error {
|
||||
var responseMsg []byte
|
||||
if h.handshakeMethodAuth {
|
||||
responseMsg = h.preparedMsg.responseAuthMsg
|
||||
} else {
|
||||
responseMsg = h.preparedMsg.responseHelloMsg
|
||||
}
|
||||
|
||||
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := h.validator.Validate(authPayload); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
}
|
||||
@@ -7,16 +7,13 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
@@ -28,6 +25,7 @@ type Relay struct {
|
||||
|
||||
store *Store
|
||||
instanceURL string
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
@@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||
}
|
||||
|
||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("prepare message: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
acceptTime := time.Now()
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := r.handshake(conn)
|
||||
h := handshake{
|
||||
conn: conn,
|
||||
validator: r.validator,
|
||||
preparedMsg: r.preparedMsg,
|
||||
}
|
||||
peerID, err := h.handshakeReceive()
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
if cErr := conn.Close(); cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
@@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
|
||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
storeTime := time.Now()
|
||||
r.store.AddPeer(peer)
|
||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
@@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
|
||||
if err := h.handshakeResponse(); err != nil {
|
||||
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||
peer.Close()
|
||||
}
|
||||
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
|
||||
}
|
||||
|
||||
// Shutdown closes the relay server
|
||||
@@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
||||
func (r *Relay) InstanceURL() string {
|
||||
return r.instanceURL
|
||||
}
|
||||
|
||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
var (
|
||||
responseMsg []byte
|
||||
peerID []byte
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
case messages.MsgTypeAuth:
|
||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = conn.Write(responseMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := r.validator.Validate(authPayload); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ download_release_binary() {
|
||||
|
||||
# Unzip the app and move to INSTALL_DIR
|
||||
unzip -q -o "$BINARY_NAME"
|
||||
mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
||||
mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
||||
else
|
||||
${SUDO} mkdir -p "$INSTALL_DIR"
|
||||
tar -xzvf "$BINARY_NAME"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
@@ -13,8 +14,6 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
|
||||
var (
|
||||
tmpDir string
|
||||
)
|
||||
|
||||
type TestConfig struct {
|
||||
SomeMap map[string]string
|
||||
SomeArray []string
|
||||
SomeField int
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
err := os.RemoveAll(tmpDir)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("Config", func() {
|
||||
Context("in JSON format", func() {
|
||||
It("should be written and read successfully", func() {
|
||||
|
||||
m := make(map[string]string)
|
||||
m["key1"] = "value1"
|
||||
m["key2"] = "value2"
|
||||
|
||||
arr := []string{"value1", "value2"}
|
||||
|
||||
written := &TestConfig{
|
||||
SomeMap: m,
|
||||
SomeArray: arr,
|
||||
SomeField: 99,
|
||||
}
|
||||
|
||||
err := util.WriteJson(tmpDir+"/testconfig.json", written)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(read).NotTo(BeNil())
|
||||
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
|
||||
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
|
||||
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
|
||||
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
|
||||
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Copying file contents", func() {
|
||||
Context("from one file to another", func() {
|
||||
It("should be successful", func() {
|
||||
|
||||
src := tmpDir + "/copytest_src"
|
||||
dst := tmpDir + "/copytest_dst"
|
||||
|
||||
err := util.WriteJson(src, []string{"1", "2", "3"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = util.CopyFileContents(src, dst)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
hashSrc := md5.New()
|
||||
hashDst := md5.New()
|
||||
|
||||
srcFile, err := os.Open(src)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
dstFile, err := os.Open(dst)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
_, err = io.Copy(hashSrc, srcFile)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
_, err = io.Copy(hashDst, dstFile)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = srcFile.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = dstFile.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Handle config file without full path", func() {
|
||||
Context("config file handling", func() {
|
||||
It("should be successful", func() {
|
||||
written := &TestConfig{
|
||||
SomeField: 123,
|
||||
}
|
||||
cfgFile := "test_cfg.json"
|
||||
defer os.Remove(cfgFile)
|
||||
|
||||
err := util.WriteJson(cfgFile, written)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
read, err := util.ReadJson(cfgFile, &TestConfig{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(read).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,12 +1,142 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type TestConfig struct {
|
||||
SomeMap map[string]string
|
||||
SomeArray []string
|
||||
SomeField int
|
||||
}
|
||||
|
||||
func TestConfigJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *TestConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JSON config",
|
||||
config: &TestConfig{
|
||||
SomeMap: map[string]string{"key1": "value1", "key2": "value2"},
|
||||
SomeArray: []string{"value1", "value2"},
|
||||
SomeField: 99,
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, read)
|
||||
require.Equal(t, tt.config.SomeMap["key1"], read.(*TestConfig).SomeMap["key1"])
|
||||
require.Equal(t, tt.config.SomeMap["key2"], read.(*TestConfig).SomeMap["key2"])
|
||||
require.ElementsMatch(t, tt.config.SomeArray, read.(*TestConfig).SomeArray)
|
||||
require.Equal(t, tt.config.SomeField, read.(*TestConfig).SomeField)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFileContents(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
srcContent []string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Copy file contents successfully",
|
||||
srcContent: []string{"1", "2", "3"},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
src := tmpDir + "/copytest_src"
|
||||
dst := tmpDir + "/copytest_dst"
|
||||
|
||||
err := WriteJson(src, tt.srcContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = CopyFileContents(src, dst)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashSrc := md5.New()
|
||||
hashDst := md5.New()
|
||||
|
||||
srcFile, err := os.Open(src)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = srcFile.Close()
|
||||
}()
|
||||
|
||||
dstFile, err := os.Open(dst)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = dstFile.Close()
|
||||
}()
|
||||
|
||||
_, err = io.Copy(hashSrc, srcFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(hashDst, dstFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, hex.EncodeToString(hashSrc.Sum(nil)[:16]), hex.EncodeToString(hashDst.Sum(nil)[:16]))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConfigFileWithoutFullPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *TestConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Handle config file without full path",
|
||||
config: &TestConfig{
|
||||
SomeField: 123,
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfgFile := "test_cfg.json"
|
||||
defer func() {
|
||||
_ = os.Remove(cfgFile)
|
||||
}()
|
||||
|
||||
err := WriteJson(cfgFile, tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(cfgFile, &TestConfig{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, read)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadJsonWithEnvSub(t *testing.T) {
|
||||
type Config struct {
|
||||
CertFile string `json:"CertFile"`
|
||||
|
||||
@@ -11,7 +11,8 @@ import (
|
||||
|
||||
const (
|
||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||
NetbirdFwmark = 0x1BD00
|
||||
NetbirdFwmark = 0x1BD00
|
||||
PreroutingFwmark = 0x1BD01
|
||||
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user