Compare commits

...

24 Commits

Author SHA1 Message Date
Bethuel Mmbaga
96d2207684 Fix JSON function compatibility for SQLite and PostgreSQL (#2746)
resolves the issue with json_array_length compatibility between SQLite and PostgreSQL. It adjusts the query to conditionally cast types:

PostgreSQL: Casts to json with ::json.
SQLite: Uses the text representation directly.
2024-10-16 17:55:30 +02:00
Emre Oksum
f942491b91 Update Zitadel version on quickstart script (#2744)
Update Zitadel version at docker compose in quickstart script from 2.54.3 to 2.54.10 because 2.54.3 isn't stable and has a lot of bugs.
2024-10-16 17:51:21 +02:00
Viktor Liu
8c8900be57 [client] Exclude loopback from NAT (#2747) 2024-10-16 17:35:59 +02:00
Maycon Santos
cee95461d1 [client] Add universal bin build and update sign workflow version (#2738)
* Add universal binaries build for macOS

* update sign pipeline version

* handle info.plist in sign workflow
2024-10-15 15:03:17 +02:00
ctrl-zzz
49e65109d2 Add session expire functionality based on inactivity (#2326)
Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required.
2024-10-13 14:52:43 +02:00
Zoltan Papp
d93dd4fc7f [relay-server] Move the handshake logic to separated struct (#2648)
* Move the handshake logic to separated struct

- The server will response to the client after it ready to process the peer
- Preload the response messages

* Fix deprecated lint issue

* Fix error handling

* [relay-server] Relay measure auth time (#2675)

Measure the Relay client's authentication time
2024-10-12 18:21:34 +02:00
Viktor Liu
3a88ac78ff [client] Add table filter rules using iptables (#2727)
This specifically concerns the established/related rule since this one is not compatible with iptables-nft even if it is generated the same way by iptables-translate.
2024-10-12 10:44:48 +02:00
Maycon Santos
da3a053e2b [management] Refactor getAccountIDWithAuthorizationClaims (#2715)
This change restructures the getAccountIDWithAuthorizationClaims method to improve readability, maintainability, and performance.

- have dedicated methods to handle possible cases
- introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods
- Remove GetAccount and SaveAccount dependency
- added tests
2024-10-12 08:35:51 +02:00
Zoltan Papp
0e95f16cdd [relay,client] Relay/fix/wg roaming (#2691)
If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature.

Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources.
Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe.
2024-10-11 16:24:30 +02:00
pascal-fischer
b2379175fe [signal] new signal dispatcher version (#2722) 2024-10-10 16:23:46 +02:00
Viktor Liu
09bdd271f1 [client] Improve route acl (#2705)
- Update nftables library to v0.2.0
- Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g. by Docker)
- Add nft rules to internal map only if flush was successful
- Improve error message if handle is 0 (= not found or hasn't been refreshed)
- Add debug logging when route rules are added
- Replace nftables userdata (rule ID) with a rule hash
2024-10-10 15:54:34 +02:00
Misha Bragin
208a2b7169 Add billing user role (#2714) 2024-10-10 14:14:56 +02:00
pascal-fischer
8284ae959c [management] Move testdata to sql files (#2693) 2024-10-10 12:35:03 +02:00
Maycon Santos
6ce09bca16 Add support to envsub go management configurations (#2708)
This change allows users to reference environment variables using Go template format, like {{ .EnvName }}

Moved the previous file test code to file_suite_test.go.
2024-10-09 20:46:23 +02:00
pascal-fischer
b79c1d64cc [management] Make max open db conns configurable (#2713) 2024-10-09 20:17:25 +02:00
Misha Bragin
b1eda43f4b Add Link to the Lawrence Systems video (#2711) 2024-10-09 14:56:25 +02:00
pascal-fischer
d4ef84fe6e [management] Propagate error in store errors (#2709) 2024-10-09 14:33:58 +02:00
Viktor Liu
44e8107383 [client] Limit P2P attempts and restart on specific events (#2657) 2024-10-08 11:21:11 +02:00
Bethuel Mmbaga
2c1f5e46d5 [management] Validate peer ownership during login (#2704)
* check peer ownership in login

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update error message

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-07 19:06:26 +03:00
pascal-fischer
dbec24b520 [management] Remove admin check on getAccountByID (#2699) 2024-10-06 17:01:13 +02:00
Carlos Hernandez
f603cd9202 [client] Check wginterface instead of engine ctx (#2676)
Moving code to ensure wgInterface is gone right after context is
cancelled/stop in the off chance that on next retry the backoff
operation is permanently cancelled and interface is abandoned without
destroying.
2024-10-04 19:15:16 +02:00
Bethuel Mmbaga
5897a48e29 fix wrong reference (#2695)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:55:25 +03:00
Bethuel Mmbaga
8bf729c7b4 [management] Add AccountExists to AccountManager (#2694)
* Add AccountExists method to account manager interface

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:09:40 +03:00
Bethuel Mmbaga
7f09b39769 [management] Refactor User JWT group sync (#2690)
* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 17:17:01 +03:00
80 changed files with 3729 additions and 1442 deletions

View File

@@ -42,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -16,7 +16,7 @@ jobs:
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
@@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.14"
SIGN_PIPE_VER: "v0.0.15"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -20,7 +20,7 @@ concurrency:
jobs:
release:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
env:
flags: ""
steps:
@@ -223,4 +223,4 @@ jobs:
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'

View File

@@ -96,6 +96,9 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives:
- builds:
- netbird

View File

@@ -23,6 +23,9 @@ builds:
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives:
- builds:
- netbird-ui-darwin

View File

@@ -49,6 +49,8 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features
@@ -62,6 +64,7 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string {
signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
_, mgmLis := startManagement(t, config, "../testdata/store.sql")
mgmAddr := mgmLis.Addr().String()
return mgmAddr
}
@@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -21,13 +22,19 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
entries map[string][][]string
optionalEntries map[string][]entry
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
}
err := ipset.Init()
@@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
}
m.seedInitialEntries()
m.seedInitialOptionalEntries()
err = m.cleanChains()
if err != nil {
@@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil
}
@@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}

View File

@@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
}
func (m *Manager) AddRouteFiltering(
sources [] netip.Prefix,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,

View File

@@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := tableFilter
if chain == chainRTNAT {
table = tableNat
}
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
@@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %v", chain, err)
return fmt.Errorf("create chain %s: %w", chain, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %v", err)
return fmt.Errorf("insert established rule: %w", err)
}
return r.addJumpRules()
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) createAndSetupChain(chain string) error {
@@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {

View File

@@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
sortPrefixes(sources)
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
@@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// sortPrefixes sorts the given slice of netip.Prefix in place.
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func sortPrefixes(prefixes []netip.Prefix) {
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {

View File

@@ -11,12 +11,14 @@ import (
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -29,6 +31,7 @@ const (
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
@@ -40,15 +43,14 @@ var (
)
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routeingFwChainName string
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -61,7 +63,7 @@ type iFaceMapper interface {
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
@@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
}
m := &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routeingFwChainName: routeingFwChainName,
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
@@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
}
// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
@@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
func (m *AclManager) addJumpRulesToRtForward() {
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Chain: chainFwFilter,
Exprs: expressions,
})
}
@@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,

View File

@@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},

View File

@@ -10,6 +10,8 @@ import (
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -24,7 +26,7 @@ import (
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
@@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
}
}
err = r.cleanUpDefaultForwardRules()
err = r.removeAcceptForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
@@ -97,40 +99,7 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.cleanUpDefaultForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
@@ -157,25 +125,28 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.acceptForwardRules()
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
err := r.refreshRulesMap()
if err != nil {
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
@@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
@@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey),
}
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
rule = r.conn.AddRule(rule)
return ruleKey, r.conn.Flush()
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
@@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
}
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
@@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Register: 1,
Data: intf,
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: lo,
},
}
exprs = append(exprs, sourceExp...)
@@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() {
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() {
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
@@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() {
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes a nftables rule pair from nat chains
@@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}

View File

@@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
@@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
@@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
@@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}

View File

@@ -1,8 +1,11 @@
package id
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
@@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
dPort *manager.Port,
action manager.Action,
) RuleID {
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
manager.SortPrefixes(sources)
h := sha256.New()
// Write all fields to the hasher, with delimiters
h.Write([]byte("sources:"))
for _, src := range sources {
h.Write([]byte(src.String()))
h.Write([]byte(","))
}
h.Write([]byte("destination:"))
h.Write([]byte(destination.String()))
h.Write([]byte("proto:"))
h.Write([]byte(proto))
h.Write([]byte("sPort:"))
if sPort != nil {
h.Write([]byte(sPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("dPort:"))
if dPort != nil {
h.Write([]byte(dPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("action:"))
h.Write([]byte(strconv.Itoa(int(action))))
hash := hex.EncodeToString(h.Sum(nil))
// prepend destination prefix to be able to identify the rule
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
}

View File

@@ -269,12 +269,6 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -294,6 +288,15 @@ func (c *ConnectClient) run(
}
<-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()

View File

@@ -251,6 +251,13 @@ func (e *Engine) Stop() error {
}
log.Info("Network monitor: stopped")
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
err := e.removeAllPeers()
if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
@@ -1116,18 +1123,12 @@ func (e *Engine) close() {
}
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
}
e.wgInterface = nil
}
if !isNil(e.sshServer) {
@@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
@@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
}
func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return
}
e.dnsServer.Stop()
e.dnsServer = nil
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
@@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() {
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.

View File

@@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
@@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}

View File

@@ -32,6 +32,8 @@ const (
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2
reconnectMaxElapsedTime = 30 * time.Minute
)
type WgConfig struct {
@@ -80,9 +82,8 @@ type Conn struct {
config ConnConfig
statusRecorder *Status
wgProxyFactory *wgproxy.Factory
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
allowedIPsIP string
handshaker *Handshaker
@@ -103,11 +104,14 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations
iCEDisconnected chan bool
relayDisconnected chan bool
connMonitor *ConnMonitor
reconnectCh <-chan struct{}
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -123,21 +127,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
log: connLog,
ctx: ctx,
ctxCancel: ctxCancel,
config: config,
statusRecorder: statusRecorder,
wgProxyFactory: wgProxyFactory,
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
log: connLog,
ctx: ctx,
ctxCancel: ctxCancel,
config: config,
statusRecorder: statusRecorder,
wgProxyFactory: wgProxyFactory,
signaler: signaler,
iFaceDiscover: iFaceDiscover,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
signaler,
iFaceDiscover,
config,
conn.relayDisconnected,
conn.iCEDisconnected,
)
rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
@@ -200,6 +214,8 @@ func (conn *Conn) startHandshakeAndReconnect() {
conn.log.Errorf("failed to send initial offer: %v", err)
}
go conn.connMonitor.Start(conn.ctx)
if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry()
} else {
@@ -240,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil
}
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil {
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -309,12 +324,14 @@ func (conn *Conn) reconnectLoopWithRetry() {
// With it, we can decrease to send necessary offer
select {
case <-conn.ctx.Done():
return
case <-time.After(3 * time.Second):
}
ticker := conn.prepareExponentTicker()
defer ticker.Stop()
time.Sleep(1 * time.Second)
for {
select {
case t := <-ticker.C:
@@ -342,20 +359,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
case changed := <-conn.relayDisconnected:
if !changed {
continue
}
conn.log.Debugf("Relay state changed, reset reconnect timer")
ticker.Stop()
ticker = conn.prepareExponentTicker()
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, reset reconnect timer")
case <-conn.reconnectCh:
ticker.Stop()
ticker = conn.prepareExponentTicker()
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
@@ -366,10 +374,10 @@ func (conn *Conn) reconnectLoopWithRetry() {
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.01,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: conn.config.Timeout,
MaxElapsedTime: 0,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
@@ -420,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
if conn.currentConnPriority > priority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return
}
conn.log.Infof("set ICE to active connection")
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
if err != nil {
return
var (
ep *net.UDPAddr
wgProxy wgproxy.Proxy
err error
)
if iceConnInfo.RelayedOnLocal {
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
ep = directEp
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher()
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if wgProxy != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close turn connection: %v", err)
}
}
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
if wgProxy != nil {
wgProxy.Work()
}
if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
wgConfigWorkaround()
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyICE = wgProxy
conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
@@ -482,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState)
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
@@ -496,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
default:
}
conn.notifyReconnectLoopICEDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -520,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected)
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
}
conn.connIDRelay = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
}
@@ -587,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return
}
log.Debugf("relay connection is disconnected")
conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay {
log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil {
conn.log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
}
if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil
}
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
default:
}
conn.notifyReconnectLoopRelayDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -617,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
}
}
@@ -755,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true
}
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
@@ -775,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
}
}
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
if !iceConnInfo.RelayedOnLocal {
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
return nil, err
}
return wgProxy, nil
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Work()
}
return ep, wgProxy, nil
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {

View File

@@ -0,0 +1,212 @@
package peer
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
signalerMonitorPeriod = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ConnMonitor struct {
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
config ConnConfig
relayDisconnected chan bool
iCEDisconnected chan bool
reconnectCh chan struct{}
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
reconnectCh := make(chan struct{}, 1)
cm := &ConnMonitor{
signaler: signaler,
iFaceDiscover: iFaceDiscover,
config: config,
relayDisconnected: relayDisconnected,
iCEDisconnected: iCEDisconnected,
reconnectCh: reconnectCh,
}
return cm, reconnectCh
}
func (cm *ConnMonitor) Start(ctx context.Context) {
signalerReady := make(chan struct{}, 1)
go cm.monitorSignalerReady(ctx, signalerReady)
localCandidatesChanged := make(chan struct{}, 1)
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
for {
select {
case changed := <-cm.relayDisconnected:
if !changed {
continue
}
log.Debugf("Relay state changed, triggering reconnect")
cm.triggerReconnect()
case changed := <-cm.iCEDisconnected:
if !changed {
continue
}
log.Debugf("ICE state changed, triggering reconnect")
cm.triggerReconnect()
case <-signalerReady:
log.Debugf("Signaler became ready, triggering reconnect")
cm.triggerReconnect()
case <-localCandidatesChanged:
log.Debugf("Local candidates changed, triggering reconnect")
cm.triggerReconnect()
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
if cm.signaler == nil {
return
}
ticker := time.NewTicker(signalerMonitorPeriod)
defer ticker.Stop()
lastReady := true
for {
select {
case <-ticker.C:
currentReady := cm.signaler.Ready()
if !lastReady && currentReady {
select {
case signalerReady <- struct{}{}:
default:
}
}
lastReady = currentReady
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
ufrag, pwd, err := generateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
log.Warnf("Failed to handle candidate tick: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
log.Debugf("Gathering ICE candidates")
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("wait for gathering: %w", ctx.Err())
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
if changed := cm.updateCandidates(candidates); changed {
select {
case localCandidatesChanged <- struct{}{}:
default:
}
}
return nil
}
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func (cm *ConnMonitor) triggerReconnect() {
select {
case cm.reconnectCh <- struct{}{}:
default:
}
}

View File

@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ifaceBlacklist)
}

View File

@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
}

View File

@@ -233,41 +233,16 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := w.newStdNet()
transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: relaySupport,
InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
UDPMux: w.config.ICEConfig.UDPMux,
UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: w.localUfrag,
LocalPwd: w.localPwd,
}
if w.config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
w.sentExtraSrflx = false
agent, err := ice.NewAgent(agentConfig)
agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd)
if err != nil {
return nil, err
return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
@@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
}
}
func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
UDPMux: config.ICEConfig.UDPMux,
UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{

View File

@@ -5,7 +5,6 @@ package ebpf
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)

View File

@@ -4,8 +4,13 @@ package ebpf
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
return fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
AddTurnConn(ctx context.Context, turnConn net.Conn) error
EndpointAddr() *net.UDPAddr
Work()
Pause()
CloseConn() error
}

View File

@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}

View File

@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
// AddTurnConn
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
return err
}
go p.proxyToRemote()
go p.proxyToLocal()
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err
return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the localConn
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
for ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
// if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
for {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)

View File

@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}

36
client/testdata/store.sql vendored Normal file
View File

@@ -0,0 +1,36 @@
PRAGMA foreign_keys=OFF;
BEGIN TRANSACTION;
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');
COMMIT;

Binary file not shown.

View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleExecutable</key>
<string>netbird-ui</string>
<key>CFBundleIconFile</key>
<string>Netbird</string>
<key>LSUIElement</key>
<string>1</string>
</dict>
</plist>

20
go.mod
View File

@@ -19,8 +19,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.24.0
golang.org/x/sys v0.21.0
golang.org/x/crypto v0.28.0
golang.org/x/sys v0.26.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -38,6 +38,7 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/davecgh/go-spew v1.1.1
github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.7.0
github.com/gliderlabs/ssh v0.3.4
@@ -45,7 +46,7 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/google/nftables v0.2.0
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
@@ -55,12 +56,12 @@ require (
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -89,10 +90,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.26.0
golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.21.0
golang.org/x/sync v0.8.0
golang.org/x/term v0.25.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
@@ -133,7 +134,6 @@ require (
github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
@@ -219,7 +219,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect

36
go.sum
View File

@@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
@@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@@ -873,7 +873,7 @@ services:
zitadel:
restart: 'always'
networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file:
- ./zitadel.env

View File

@@ -4,7 +4,6 @@ import (
"context"
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"
@@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}
func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) {
t.Helper()
dataDir := t.TempDir()
err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
if err != nil {
t.Fatal(err)
}
store, err := mgmt.NewSqliteStore(ctx, dataDir, nil)
if err != nil {
return nil, nil, err
}
return store, func() {
store.Close(ctx)
os.Remove(filepath.Join(dataDir, "store.db"))
}, nil
}

View File

@@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
}

File diff suppressed because it is too large Load Diff

View File

@@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
var (
publicDomain = "public.com"
privateDomain = "private.com"
unknownDomain = "unknown.com"
)
defaultInitAccount := initUserParams{
Domain: publicDomain,
UserId: "defaultUser",
}
initUnknown := defaultInitAccount
initUnknown.DomainCategory = UnknownCategory
initUnknown.Domain = unknownDomain
privateInitAccount := defaultInitAccount
privateInitAccount.Domain = privateDomain
privateInitAccount.DomainCategory = PrivateCategory
testCases := []struct {
name string
inputClaims jwtclaims.AuthorizationClaims
inputInitUserParams initUserParams
@@ -479,168 +498,143 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
expectedPrimaryDomainStatus bool
expectedCreatedBy string
expectedUsers []string
}
var (
publicDomain = "public.com"
privateDomain = "private.com"
unknownDomain = "unknown.com"
)
defaultInitAccount := initUserParams{
Domain: publicDomain,
UserId: "defaultUser",
}
testCase1 := test{
name: "New User With Public Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: PublicCategory,
}{
{
name: "New User With Public Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: PublicCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"},
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"},
}
initUnknown := defaultInitAccount
initUnknown.DomainCategory = UnknownCategory
initUnknown.Domain = unknownDomain
testCase2 := test{
name: "New User With Unknown Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: UnknownCategory,
{
name: "New User With Unknown Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: UnknownCategory,
},
inputInitUserParams: initUnknown,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: unknownDomain,
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user",
expectedUsers: []string{"unknown-domain-user"},
},
inputInitUserParams: initUnknown,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: unknownDomain,
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user",
expectedUsers: []string{"unknown-domain-user"},
}
testCase3 := test{
name: "New User With Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
{
name: "New User With Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
privateInitAccount := defaultInitAccount
privateInitAccount.Domain = privateDomain
privateInitAccount.DomainCategory = PrivateCategory
testCase4 := test{
name: "New Regular User With Existing Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: PrivateCategory,
{
name: "New Regular User With Existing Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputUpdateAttrs: true,
inputInitUserParams: privateInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
},
inputUpdateAttrs: true,
inputInitUserParams: privateInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
}
testCase5 := test{
name: "Existing User With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
{
name: "Existing User With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
}
testCase6 := test{
name: "Existing Account Id With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
{
name: "Existing Account Id With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
},
inputUpdateClaimAccount: true,
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
},
inputUpdateClaimAccount: true,
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
}
testCase7 := test{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -671,17 +665,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId(context.Background(), "", userId, domain)
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
@@ -885,7 +878,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
}
func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -894,7 +887,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
@@ -903,14 +896,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
t.Errorf("expected an error when user ID is empty")
}
}
@@ -1669,7 +1661,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@@ -1684,7 +1676,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1696,7 +1688,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1742,7 +1734,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1770,7 +1762,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1790,7 +1782,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1802,7 +1794,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1850,7 +1842,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@@ -1861,9 +1853,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
@@ -1968,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
}
}
func TestAccount_GetInactivePeers(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "Peers with inactivity expiration disabled, no expired peers",
peers: map[string]*nbpeer.Peer{
"peer-1": {
InactivityExpirationEnabled: false,
},
"peer-2": {
InactivityExpirationEnabled: false,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Two peers expired",
peers: map[string]*nbpeer.Peer{
"peer-1": {
ID: "peer-1",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC().Add(-45 * time.Second),
Connected: false,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
UserID: userID,
},
"peer-2": {
ID: "peer-2",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC().Add(-45 * time.Second),
Connected: false,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
UserID: userID,
},
"peer-3": {
ID: "peer-3",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
UserID: userID,
},
},
expectedPeers: map[string]struct{}{
"peer-1": {},
"peer-2": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
}
expiredPeers := account.GetInactivePeers()
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
for _, peer := range expiredPeers {
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
t.Fatalf("expected to have peer %s expired", peer.ID)
}
}
})
}
}
func TestAccount_GetPeersWithExpiration(t *testing.T) {
type test struct {
name string
@@ -2037,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
}
}
func TestAccount_GetPeersWithInactivity(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "No account peers, no peers with expiration",
peers: map[string]*nbpeer.Peer{},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration disabled, no peers with expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration enabled, return peers with expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
ID: "peer-1",
InactivityExpirationEnabled: true,
UserID: userID,
},
"peer-2": {
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expectedPeers: map[string]struct{}{
"peer-1": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
}
actual := account.GetPeersWithInactivity()
assert.Len(t, actual, len(testCase.expectedPeers))
if len(testCase.expectedPeers) > 0 {
for k := range testCase.expectedPeers {
contains := false
for _, peer := range actual {
if k == peer.ID {
contains = true
}
}
assert.True(t, contains)
}
}
})
}
}
func TestAccount_GetNextPeerExpiration(t *testing.T) {
type test struct {
name string
@@ -2198,9 +2340,175 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}
}
func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expiration time.Duration
expirationEnabled bool
expectedNextRun bool
expectedNextExpiration time.Duration
}
expectedNextExpiration := time.Minute
testCases := []test{
{
name: "No peers, no expiration",
peers: map[string]*nbpeer.Peer{},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "No connected peers, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: false,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: false,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Connected peers with disabled expiration, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Expired peers, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "To be expired peer, return expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: false,
LoginExpired: false,
LastSeen: time.Now().Add(-1 * time.Second),
},
InactivityExpirationEnabled: true,
LastLogin: time.Now().UTC(),
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
},
expiration: time.Minute,
expirationEnabled: false,
expectedNextRun: true,
expectedNextExpiration: expectedNextExpiration,
},
{
name: "Peers added with setup keys, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: false,
},
InactivityExpirationEnabled: true,
SetupKey: "key",
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: false,
},
InactivityExpirationEnabled: true,
SetupKey: "key",
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
}
expiration, ok := account.GetNextInactivePeerExpiration()
assert.Equal(t, testCase.expectedNextRun, ok)
if testCase.expectedNextRun {
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
} else {
assert.Equal(t, expiration, testCase.expectedNextExpiration)
}
})
}
}
func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account
account := &Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@@ -2211,62 +2519,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
},
Settings: &Settings{GroupsPropagationEnabled: true},
Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{
"user1": {Id: "user1"},
"user2": {Id: "user2"},
"user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2", AccountID: "accountID"},
},
}
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("empty jwt groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.False(t, updated, "account should not be updated")
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
t.Run("jwt match existing api group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated")
assert.Equal(t, 0, len(account.Users["user1"].AutoGroups))
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated")
assert.Equal(t, 1, len(account.Users["user1"].AutoGroups))
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
t.Run("add jwt group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 2, "new group should be added")
assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
t.Run("existed group not update", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group2"})
assert.False(t, updated, "account should not be updated")
assert.Len(t, account.Groups, 2, "groups count should not be changed")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
t.Run("add new group", func(t *testing.T) {
updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 3, "new group should be added")
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
})
}
@@ -2366,7 +2732,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
func createStore(t TB) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}

View File

@@ -139,6 +139,13 @@ const (
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted Activity = 62
PeerInactivityExpirationEnabled Activity = 63
PeerInactivityExpirationDisabled Activity = 64
AccountPeerInactivityExpirationEnabled Activity = 65
AccountPeerInactivityExpirationDisabled Activity = 66
AccountPeerInactivityExpirationDurationUpdated Activity = 67
)
var activityMap = map[Activity]Code{
@@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
PostureCheckCreated: {"Posture check created", "posture.check.created"},
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
}
// StringCode returns a string code of the activity

View File

@@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}

View File

@@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
account.Settings = &Settings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: DefaultPeerLoginExpiration,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
}
}

View File

@@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {

View File

@@ -54,6 +54,14 @@ components:
description: Period of time after which peer login expires (seconds).
type: integer
example: 43200
peer_inactivity_expiration_enabled:
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
type: boolean
example: true
peer_inactivity_expiration:
description: Period of time of inactivity after which peer session expires (seconds).
type: integer
example: 43200
regular_users_view_blocked:
description: Allows blocking regular users from viewing parts of the system.
type: boolean
@@ -81,6 +89,8 @@ components:
required:
- peer_login_expiration_enabled
- peer_login_expiration
- peer_inactivity_expiration_enabled
- peer_inactivity_expiration
- regular_users_view_blocked
AccountExtraSettings:
type: object
@@ -243,6 +253,9 @@ components:
login_expiration_enabled:
type: boolean
example: false
inactivity_expiration_enabled:
type: boolean
example: false
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
@@ -251,6 +264,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
- inactivity_expiration_enabled
Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
@@ -327,6 +341,10 @@ components:
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
inactivity_expiration_enabled:
description: Indicates whether peer inactivity expiration has been enabled or not
type: boolean
example: false
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
@@ -354,6 +372,7 @@ components:
- last_seen
- login_expiration_enabled
- login_expired
- inactivity_expiration_enabled
- os
- ssh_enabled
- user_id

View File

@@ -220,6 +220,12 @@ type AccountSettings struct {
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
// PeerLoginExpiration Period of time after which peer login expires (seconds).
PeerLoginExpiration int `json:"peer_login_expiration"`
@@ -538,6 +544,9 @@ type Peer struct {
// Id Peer ID
Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip string `json:"ip"`
@@ -613,6 +622,9 @@ type PeerBatch struct {
// Id Peer ID
Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip string `json:"ip"`
@@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest.
type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
}
// PersonalAccessToken defines model for PersonalAccessToken.

View File

@@ -7,6 +7,8 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -14,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
}
if req.ApprovalRequired != nil {
@@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
}
return &api.Peer{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
}
}
@@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
}
}

View File

@@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile)
store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
}
func Test_SyncStatusRace(t *testing.T) {
t.Skip()
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
}
@@ -482,9 +483,10 @@ func Test_SyncStatusRace(t *testing.T) {
}
func testSyncStatusRace(t *testing.T) {
t.Helper()
t.Skip()
dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
t.Skip()
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
}
@@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",

View File

@@ -58,7 +58,7 @@ var _ = Describe("Management service", func() {
Expect(err).NotTo(HaveOccurred())
config.Datadir = dataDir
s, listener = startServer(config, dataDir, "testdata/store.sqlite")
s, listener = startServer(config, dataDir, "testdata/store.sql")
addr = listener.Addr().String()
client, conn = createRawClient(addr)
@@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir)
store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}

View File

@@ -27,7 +27,8 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -58,7 +59,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -194,14 +195,22 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
// AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
if am.AccountExistsFunc != nil {
return am.AccountExistsFunc(ctx, accountID)
}
return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
}
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
}
return "", status.Errorf(
codes.Unimplemented,
"method GetAccountIDByUserOrAccountID is not implemented",
"method GetAccountIDByUserID is not implemented",
)
}
@@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}

View File

@@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}

View File

@@ -110,6 +110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
@@ -138,25 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil {
return err
return false, err
}
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
return oldStatus.LoginExpired, nil
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -219,6 +234,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
@@ -442,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
@@ -693,6 +728,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updateRemotePeers := false
if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
}
changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err

View File

@@ -38,6 +38,8 @@ type Peer struct {
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
InactivityExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time
// CreatedAt records the time the peer was created
@@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer {
CreatedAt: p.CreatedAt,
Ephemeral: p.Ephemeral,
Location: p.Location,
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
}
}
@@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
p.Status = newStatus
}
// SessionExpired indicates whether the peer's session has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
return false, 0
}
expiresAt := p.Status.LastSeen.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).

View File

@@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
}
}
func TestPeer_SessionExpired(t *testing.T) {
tt := []struct {
name string
expirationEnabled bool
lastLogin time.Time
connected bool
expected bool
accountSettings *Settings
}{
{
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
expirationEnabled: false,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
},
expected: false,
},
{
name: "Peer Inactivity Should Expire",
expirationEnabled: true,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: true,
},
{
name: "Peer Inactivity Should Not Expire",
expirationEnabled: true,
connected: true,
lastLogin: time.Now().UTC(),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: false,
},
}
for _, c := range tt {
t.Run(c.name, func(t *testing.T) {
peerStatus := &nbpeer.PeerStatus{
Connected: c.connected,
}
peer := &nbpeer.Peer{
InactivityExpirationEnabled: c.expirationEnabled,
LastLogin: c.lastLogin,
Status: peerStatus,
UserID: userID,
}
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
assert.Equal(t, expired, c.expected)
})
}
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -1004,7 +1066,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1069,7 +1131,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1135,7 +1197,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1188,6 +1250,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
@@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
if err != nil {
return nil, err
}
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
if err != nil {
conns = runtime.NumCPU()
}
sql.SetMaxOpenConns(conns)
log.Infof("Set max open db connections to %d", conns)
if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
@@ -316,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
return nil
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
accountCopy := Account{
Domain: domain,
DomainCategory: category,
IsDomainPrimaryAccount: isPrimaryDomain,
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account %s", accountID)
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -378,15 +408,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -420,7 +461,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -433,7 +474,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.NewSetupKeyNotFoundError()
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
if key.AccountID == "" {
@@ -451,7 +492,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
return "", status.NewGetAccountFromStoreError(result.Error)
}
return token.ID, nil
@@ -465,7 +506,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
@@ -500,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
return users, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
@@ -549,7 +604,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
return nil, status.NewGetAccountFromStoreError(result.Error)
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
@@ -612,7 +667,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if user.AccountID == "" {
@@ -629,7 +684,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -647,7 +702,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -665,7 +720,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -678,7 +733,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -691,7 +746,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.NewSetupKeyNotFoundError()
return "", status.NewSetupKeyNotFoundError(result.Error)
}
if accountID == "" {
@@ -712,7 +767,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
// Convert the JSON strings to net.IP objects
@@ -740,7 +795,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
}
return labels, nil
@@ -753,7 +808,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
}
@@ -765,7 +820,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
return nil, status.Errorf(status.Internal, "issue getting peer from store")
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
return &peer, nil
@@ -777,7 +832,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
return nil, status.Errorf(status.Internal, "issue getting settings from store")
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
}
@@ -893,28 +948,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
return store, nil
}
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewPostgresqlStore(ctx, dsn, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(ctx, fileStore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range fileStore.GetAllAccounts(ctx) {
err := store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
}
return store, nil
}
// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewPostgresqlStore(ctx, dsn, metrics)
@@ -945,7 +978,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
return &setupKey, nil
}
@@ -977,7 +1010,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -989,7 +1022,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
return nil
@@ -1003,7 +1036,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -1015,15 +1048,20 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
return status.Errorf(status.Internal, "issue updating group: %s", err)
}
return nil
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
return nil
@@ -1032,7 +1070,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
return nil
}
@@ -1116,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
// TODO: This fix is accepted for now, but if we need to handle this more frequently
// we may need to reconsider changing the types.
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
if s.storeEngine == PostgresStoreEngine {
query = query.Order("json_array_length(peers::json) DESC")
} else {
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@@ -1127,6 +1173,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@@ -11,14 +11,13 @@ import (
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
@@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) {
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("skip large test on non-linux OS due to environment restrictions")
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Run("SQLite", func(t *testing.T) {
store := newSqliteStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runLargeTest(t, store)
})
// create store outside to have a better time counter for the test
store := newPostgresqlStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
t.Run("PostgreSQL", func(t *testing.T) {
runLargeTest(t, store)
})
@@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
@@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
@@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
@@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
}
account.Users[testUserID] = user
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
@@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
@@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
if err != nil {
t.Fatal(err)
}
defer cleanup()
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
if err != nil {
t.Fatal(err)
}
defer cleanup()
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
defer cleanup()
if err != nil {
t.Fatal(err)
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
defer cleanup()
if err != nil {
t.Fatal(err)
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
defer cleanup()
if err != nil {
t.Fatal(err)
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
@@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
defer cleanup()
if err != nil {
t.Fatal(err)
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
defer cleanup()
if err != nil {
t.Fatal(err)
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
@@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
}
func TestMigrate(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newSqliteStore(t)
// TODO: figure out why this fails on postgres
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
err := migrate(context.Background(), store.db)
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
@@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) {
},
}
err = store.db.Save(act).Error
err = store.(*SqlStore).db.Save(act).Error
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
@@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) {
Route: route2.Route{ID: "route1"},
}
err = store.db.Save(rt).Error
err = store.(*SqlStore).db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
err = migrate(context.Background(), store.db)
err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(context.Background(), store.db)
err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
err = store.db.Delete(rt).Where("id = ?", "route1").Error
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
require.NoError(t, err, "Failed to delete Gob data")
prefix = netip.MustParsePrefix("12.0.0.0/24")
@@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) {
Peer: "peer-id",
}
err = store.db.Save(nRT).Error
err = store.(*SqlStore).db.Save(nRT).Error
require.NoError(t, err, "Failed to insert json nil slice data")
err = migrate(context.Background(), store.db)
err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
err = migrate(context.Background(), store.db)
err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
@@ -716,63 +734,15 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(context.Background(), account)
}
func newPostgresqlStore(t *testing.T) *SqlStore {
t.Helper()
cleanUp, err := testutil.CreatePGDB()
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
}
store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil)
if err != nil {
t.Fatalf("could not initialize postgresql store: %s", err)
}
require.NoError(t, err)
require.NotNil(t, store)
return store
}
func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore {
t.Helper()
store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename)
t.Cleanup(cleanUpQ)
if err != nil {
return nil
}
cleanUpP, err := testutil.CreatePGDB()
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUpP)
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
}
pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil)
require.NoError(t, err)
require.NotNil(t, store)
return pstore
}
func TestPostgresql_NewStore(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) {
}
func TestPostgresql_SaveAccount(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
@@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
@@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) {
}
func TestPostgresql_DeleteAccount(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStore(t)
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
@@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
}
account.Users[testUserID] = user
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
@@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
@@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
}
func TestPostgresql_SavePeerStatus(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
}
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
existingDomain := "test.com"
@@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
}
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
}
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
t.Fatal(err)
@@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
return
}
@@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1185,3 +1159,108 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
assert.NoError(t, err)
}
func TestSqlite_GetAccoundUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
t.Run("Should update attributes with public domain", func(t *testing.T) {
require.NoError(t, err)
domain := "example.com"
category := "public"
IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should update attributes with private domain", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should fail when account does not exist", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
require.Error(t, err)
})
}
func TestSqlite_GetGroupByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
require.NoError(t, err)
require.Equal(t, "All", group.Name)
}

View File

@@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error {
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
}
func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err)
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store

View File

@@ -9,10 +9,12 @@ import (
"os"
"path"
"path/filepath"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
@@ -56,10 +58,13 @@ type Store interface {
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -68,7 +73,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -82,6 +88,7 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -237,30 +244,41 @@ func getMigrations(ctx context.Context) []migrationFunc {
}
}
// NewTestStoreFromSqlite is only used in tests
func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
}
var store *SqlStore
var err error
var cleanUp func()
if filename == "" {
store, err = NewSqliteStore(ctx, dataDir, nil)
cleanUp = func() {
store.Close(ctx)
}
} else {
store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename)
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = storeSqliteFileName
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, nil, err
}
if filename != "" {
err = loadSQL(db, filename)
if err != nil {
return nil, nil, fmt.Errorf("failed to load SQL file: %v", err)
}
}
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
}
cleanUp := func() {
store.Close(ctx)
}
if kind == PostgresStoreEngine {
cleanUp, err = testutil.CreatePGDB()
if err != nil {
@@ -281,21 +299,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string
return store, cleanUp, nil
}
func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) {
err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
func loadSQL(db *gorm.DB, filepath string) error {
sqlContent, err := os.ReadFile(filepath)
if err != nil {
return nil, nil, err
return err
}
store, err := NewSqliteStore(ctx, dataDir, nil)
if err != nil {
return nil, nil, err
queries := strings.Split(string(sqlContent), ";")
for _, query := range queries {
query = strings.TrimSpace(query)
if query != "" {
err := db.Exec(query).Error
if err != nil {
return err
}
}
}
return store, func() {
store.Close(ctx)
os.Remove(filepath.Join(dataDir, "store.db"))
}, nil
return nil
}
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.

View File

@@ -0,0 +1,37 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0);
INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
INSERT INTO installations VALUES(1,'');

Binary file not shown.

33
management/server/testdata/store.sql vendored Normal file
View File

@@ -0,0 +1,33 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
INSERT INTO installations VALUES(1,'');

Binary file not shown.

View File

@@ -0,0 +1,35 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,'');
INSERT INTO installations VALUES(1,'');

View File

@@ -0,0 +1,35 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');

39
management/server/testdata/storev1.sql vendored Normal file
View File

@@ -0,0 +1,39 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0);
INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0);
INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0);
INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0);
INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
INSERT INTO installations VALUES(1,'');

Binary file not shown.

View File

@@ -8,21 +8,22 @@ import (
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
@@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole {
return UserRoleAdmin
case "user":
return UserRoleUser
case "billing_admin":
return UserRoleBillingAdmin
default:
return UserRoleUnknown
}
@@ -1254,6 +1257,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
// skip removing peers from group All
if group.Name == "All" {
return
}
updatedPeers := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
}
}
group.Peers = updatedPeers
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {

View File

@@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {

View File

@@ -16,8 +16,10 @@ const (
type Metrics struct {
metric.Meter
TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter
TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter
peerActivityChan chan string
@@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err
}
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv,
peers: peers,
Meter: meter,
TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv,
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
ctx: ctx,
peerActivityChan: make(chan string, 10),
@@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
m.peerLastActive[id] = time.Time{}
}
// RecordAuthenticationTime measures the time taken for peer authentication
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// RecordPeerStoreTime measures the time to store the peer in map
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1)
@@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
}
}
}
func getStandardBucketBoundaries() []float64 {
return []float64{
0.1,
0.5,
1,
5,
10,
50,
100,
500,
1000,
5000,
10000,
}
}

153
relay/server/handshake.go Normal file
View File

@@ -0,0 +1,153 @@
package server
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
)
// preparedMsg contains the marshalled success response messages
type preparedMsg struct {
responseHelloMsg []byte
responseAuthMsg []byte
}
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
rhm, err := marshalResponseHelloMsg(instanceURL)
if err != nil {
return nil, err
}
ram, err := messages.MarshalAuthResponse(instanceURL)
if err != nil {
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
}
return &preparedMsg{
responseHelloMsg: rhm,
responseAuthMsg: ram,
}, nil
}
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
addr := &address.Address{URL: instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal response address: %w", err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
}
return responseMsg, nil
}
type handshake struct {
conn net.Conn
validator auth.Validator
preparedMsg *preparedMsg
handshakeMethodAuth bool
peerID string
}
func (h *handshake) handshakeReceive() ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
}
var (
bytePeerID []byte
peerID string
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
case messages.MsgTypeAuth:
h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
return nil, err
}
h.peerID = peerID
return bytePeerID, nil
}
func (h *handshake) handshakeResponse() error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
} else {
responseMsg = h.preparedMsg.responseHelloMsg
}
if _, err := h.conn.Write(responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}
return nil
}
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}

View File

@@ -7,16 +7,13 @@ import (
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics"
)
@@ -28,6 +25,7 @@ type Relay struct {
store *Store
instanceURL string
preparedMsg *preparedMsg
closed bool
closeMu sync.RWMutex
@@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return nil, fmt.Errorf("get instance URL: %v", err)
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
}
return r, nil
}
@@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now()
r.closeMu.RLock()
defer r.closeMu.RUnlock()
if r.closed {
return
}
peerID, err := r.handshake(conn)
h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
peerID, err := h.handshakeReceive()
if err != nil {
log.Errorf("failed to handshake: %s", err)
cErr := conn.Close()
if cErr != nil {
if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
@@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer)
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
@@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()
if err := h.handshakeResponse(); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
}
// Shutdown closes the relay server
@@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
func (r *Relay) InstanceURL() string {
return r.instanceURL
}
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
var (
responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
}
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@@ -6,6 +6,7 @@ import (
"io"
"time"
"github.com/netbirdio/signal-dispatcher/dispatcher"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -13,8 +14,6 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/signal/proto"

View File

@@ -1,11 +1,15 @@
package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"text/template"
log "github.com/sirupsen/logrus"
)
@@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
return res, nil
}
// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution
func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) {
envVars := getEnvMap()
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer f.Close()
bs, err := io.ReadAll(f)
if err != nil {
return nil, err
}
t, err := template.New("").Parse(string(bs))
if err != nil {
return nil, fmt.Errorf("error parsing template: %v", err)
}
var output bytes.Buffer
// Execute the template, substituting environment variables
err = t.Execute(&output, envVars)
if err != nil {
return nil, fmt.Errorf("error executing template: %v", err)
}
err = json.Unmarshal(output.Bytes(), &res)
if err != nil {
return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err)
}
return res, nil
}
// getEnvMap Convert the output of os.Environ() to a map
func getEnvMap() map[string]string {
envMap := make(map[string]string)
for _, env := range os.Environ() {
parts := strings.SplitN(env, "=", 2)
if len(parts) == 2 {
envMap[parts[0]] = parts[1]
}
}
return envMap
}
// CopyFileContents copies contents of the given src file to the dst file
func CopyFileContents(src, dst string) (err error) {
in, err := os.Open(src)

126
util/file_suite_test.go Normal file
View File

@@ -0,0 +1,126 @@
package util_test
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/netbirdio/netbird/util"
)
var _ = Describe("Client", func() {
var (
tmpDir string
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})
Describe("Config", func() {
Context("in JSON format", func() {
It("should be written and read successfully", func() {
m := make(map[string]string)
m["key1"] = "value1"
m["key2"] = "value2"
arr := []string{"value1", "value2"}
written := &TestConfig{
SomeMap: m,
SomeArray: arr,
SomeField: 99,
}
err := util.WriteJson(tmpDir+"/testconfig.json", written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
})
})
})
Describe("Copying file contents", func() {
Context("from one file to another", func() {
It("should be successful", func() {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
}
cfgFile := "test_cfg.json"
defer os.Remove(cfgFile)
err := util.WriteJson(cfgFile, written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(cfgFile, &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
})
})
})
})

View File

@@ -1,126 +1,198 @@
package util_test
package util
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/netbirdio/netbird/util"
"reflect"
"strings"
"testing"
)
var _ = Describe("Client", func() {
var (
tmpDir string
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
func TestReadJsonWithEnvSub(t *testing.T) {
type Config struct {
CertFile string `json:"CertFile"`
Credentials string `json:"Credentials"`
NestedOption struct {
URL string `json:"URL"`
} `json:"NestedOption"`
}
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
})
type testCase struct {
name string
envVars map[string]string
jsonTemplate string
expectedResult Config
expectError bool
errorContains string
}
AfterEach(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})
tests := []testCase{
{
name: "All environment variables set",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/env_cert.crt",
Credentials: "env_credentials",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "https://env.testing.com",
},
},
expectError: false,
},
{
name: "Missing environment variable",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
// "URL" is intentionally missing
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/env_cert.crt",
Credentials: "env_credentials",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "<no value>",
},
},
expectError: false,
},
{
name: "Invalid JSON template",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`, // Note the missing closing brace in "{{ .CREDENTIALS }"
expectedResult: Config{},
expectError: true,
errorContains: "unexpected \"}\" in operand",
},
{
name: "No substitutions",
envVars: map[string]string{
"CERT_FILE": "/etc/certs/env_cert.crt",
"CREDENTIALS": "env_credentials",
"URL": "https://env.testing.com",
},
jsonTemplate: `{
"CertFile": "/etc/certs/cert.crt",
"Credentials": "admnlknflkdasdf",
"NestedOption" : {
"URL": "https://testing.com"
}
}`,
expectedResult: Config{
CertFile: "/etc/certs/cert.crt",
Credentials: "admnlknflkdasdf",
NestedOption: struct {
URL string `json:"URL"`
}{
URL: "https://testing.com",
},
},
expectError: false,
},
{
name: "Should fail when Invalid characters in variables",
envVars: map[string]string{
"CERT_FILE": `"/etc/certs/"cert".crt"`,
"CREDENTIALS": `env_credentia{ls}`,
"URL": `https://env.testing.com?param={{value}}`,
},
jsonTemplate: `{
"CertFile": "{{ .CERT_FILE }}",
"Credentials": "{{ .CREDENTIALS }}",
"NestedOption": {
"URL": "{{ .URL }}"
}
}`,
expectedResult: Config{
CertFile: `"/etc/certs/"cert".crt"`,
Credentials: `env_credentia{ls}`,
NestedOption: struct {
URL string `json:"URL"`
}{
URL: `https://env.testing.com?param={{value}}`,
},
},
expectError: true,
},
}
Describe("Config", func() {
Context("in JSON format", func() {
It("should be written and read successfully", func() {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
for key, value := range tc.envVars {
t.Setenv(key, value)
}
m := make(map[string]string)
m["key1"] = "value1"
m["key2"] = "value2"
tempFile, err := os.CreateTemp("", "config*.json")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
arr := []string{"value1", "value2"}
written := &TestConfig{
SomeMap: m,
SomeArray: arr,
SomeField: 99,
defer func() {
err = os.Remove(tempFile.Name())
if err != nil {
t.Logf("Failed to remove temp file: %v", err)
}
}()
err := util.WriteJson(tmpDir+"/testconfig.json", written)
Expect(err).NotTo(HaveOccurred())
_, err = tempFile.WriteString(tc.jsonTemplate)
if err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
err = tempFile.Close()
if err != nil {
t.Fatalf("Failed to close temp file: %v", err)
}
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
var result Config
})
})
})
_, err = ReadJsonWithEnvSub(tempFile.Name(), &result)
Describe("Copying file contents", func() {
Context("from one file to another", func() {
It("should be successful", func() {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
if tc.expectError {
if err == nil {
t.Fatalf("Expected error but got none")
}
cfgFile := "test_cfg.json"
defer os.Remove(cfgFile)
err := util.WriteJson(cfgFile, written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(cfgFile, &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
})
if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err)
}
} else {
if err != nil {
t.Fatalf("ReadJsonWithEnvSub failed: %v", err)
}
if !reflect.DeepEqual(result, tc.expectedResult) {
t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult)
}
}
})
})
})
}
}

View File

@@ -11,7 +11,8 @@ import (
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
NetbirdFwmark = 0x1BD00
PreroutingFwmark = 0x1BD01
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)