Compare commits

...

22 Commits

Author SHA1 Message Date
bcmmbaga
72513d7522 Skip network map calculation when client serial matches current
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-01-27 23:03:43 +03:00
Zoltan Papp
a1f1bf1f19 Merge branch 'main' into feat/network-map-serial 2025-12-18 15:59:53 +01:00
Zoltan Papp
b5dec3df39 Track network serial in engine 2025-12-18 15:27:49 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
Bethuel Mmbaga
031ab11178 [client] Remove select account prompt (#4912)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-12-04 14:57:29 +01:00
Hakan Sariman
20f5f00635 [client] Add unit tests for engine synchronization and Info flag copying
- Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil.
- Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields.
2025-10-17 10:03:07 +03:00
Hakan Sariman
fc141cf3a3 [client] Refactor lastNetworkMapSerial handling in GrpcClient
- Removed atomic operations for lastNetworkMapSerial and replaced them with mutex-based methods for thread-safe access.
2025-09-29 18:49:23 +07:00
Hakan Sariman
d0c65fa08e [client] Add skipNetworkMapUpdate field to SyncResponse for conditional updates 2025-09-29 18:28:14 +07:00
Hakan Sariman
f241bfa339 Refactor flag setting in Info struct to use CopyFlagsFrom method 2025-09-29 15:38:35 +07:00
Hakan Sariman
4b2cd97d5f [client] Enhance SyncRequest with NetworkMap serial tracking
- Added `networkMapSerial` field to `SyncRequest` for tracking the last known network map serial number.
- Updated `GrpcClient` to store and utilize the last network map serial during sync operations, optimizing synchronization processes.
- Improved handling of system info updates to ensure accurate metadata is sent with sync requests.
2025-09-25 19:28:35 +07:00
63 changed files with 1027 additions and 339 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy go mod tidy
``` ```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support ### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers. If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -27,7 +27,11 @@ import (
) )
const ( const (
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING" chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error var err error
r.filterTable, err = r.loadFilterTable() r.filterTable, err = r.loadFilterTable()
if err != nil { if err != nil {
log.Warnf("failed to load filter table, skipping accept rules: %v", err) log.Debugf("ip filter table not found: %v", err)
} }
return r, nil return r, nil
@@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound return nil, errFilterTableNotFound
} }
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error { func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw, Name: chainNameRoutingFw,
@@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface. // This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error { func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil return nil
} }
@@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
// filter table exists but iptables is not // iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptFilterRulesNftables() return r.acceptFilterRulesNftables(r.filterTable)
} }
return r.acceptFilterRulesIptables(ipt) return r.acceptFilterRulesIptables(ipt)
@@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else { } else {
log.Debugf("added iptables forward rule: %v", rule) log.Debugf("added iptables forward rule: %v", rule)
} }
@@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else { } else {
log.Debugf("added iptables input rule: %v", inputRule) log.Debugf("added iptables input rule: %v", inputRule)
} }
@@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
} }
func (r *router) acceptFilterRulesNftables() error { // acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{ iifRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf, Data: intf,
}, },
} }
oifRule := &nftables.Rule{ oifRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...), Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} }
r.conn.InsertRule(oifRule) r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{ inputRule := &nftables.Rule{
Table: r.filterTable, Table: chain.Table,
Chain: &nftables.Chain{ Chain: chain,
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule), UserData: []byte(userDataAcceptInputRule),
} }
r.conn.InsertRule(inputRule) r.conn.InsertRule(inputRule)
return r.conn.Flush()
} }
func (r *router) removeAcceptFilterRules() error { func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
return nil return nil
} }
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptFilterRulesNftables() return r.removeAcceptRulesFromTable(r.filterTable)
} }
return r.removeAcceptFilterRulesIptables(ipt) return r.removeAcceptFilterRulesIptables(ipt)
} }
func (r *router) removeAcceptFilterRulesNftables() error { func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil { if err != nil {
return fmt.Errorf("list chains: %v", err) return fmt.Errorf("list chains: %v", err)
} }
for _, chain := range chains { for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name { if chain.Table.Name != table.Name {
continue continue
} }
@@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue continue
} }
rules, err := r.conn.GetRules(r.filterTable, chain) if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
return fmt.Errorf("get rules: %v", err) log.Debugf("list chains for family %d: %v", family, err)
continue
} }
for _, rule := range rules { for _, chain := range allChains {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || if r.isExternalChain(chain) {
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || chains = append(chains, chain)
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
} }
} }
} }
if err := r.conn.Flush(); err != nil { return chains
return fmt.Errorf(flushError, err) }
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
} }
return nil // Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
} }
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
} }
} }
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -107,10 +107,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
if !p.providerConfig.DisablePromptLogin { if !p.providerConfig.DisablePromptLogin {
switch p.providerConfig.LoginFlag { switch p.providerConfig.LoginFlag {
case common.LoginFlagPromptLogin: case common.LoginFlagPromptLogin:
params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
case common.LoginFlagMaxAge0: case common.LoginFlagMaxAge0:
params = append(params, oauth2.SetAuthURLParam("max_age", "0")) params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
params = append(params, oauth2.SetAuthURLParam("prompt", "select_account"))
} }
} }
if p.providerConfig.LoginHint != "" { if p.providerConfig.LoginHint != "" {

View File

@@ -15,9 +15,8 @@ import (
func TestPromptLogin(t *testing.T) { func TestPromptLogin(t *testing.T) {
const ( const (
promptSelectAccountLogin = "prompt=login+select_account" promptLogin = "prompt=login"
promptSelectAccount = "prompt=select_account" maxAge0 = "max_age=0"
maxAge0 = "max_age=0"
) )
tt := []struct { tt := []struct {
@@ -27,14 +26,14 @@ func TestPromptLogin(t *testing.T) {
expectContains []string expectContains []string
}{ }{
{ {
name: "Prompt login with select account", name: "Prompt login",
loginFlag: mgm.LoginFlagPromptLogin, loginFlag: mgm.LoginFlagPromptLogin,
expectContains: []string{promptSelectAccountLogin}, expectContains: []string{promptLogin},
}, },
{ {
name: "Max age 0 with select account", name: "Max age 0",
loginFlag: mgm.LoginFlagMaxAge0, loginFlag: mgm.LoginFlagMaxAge0,
expectContains: []string{maxAge0, promptSelectAccount}, expectContains: []string{maxAge0},
}, },
{ {
name: "Disable prompt login", name: "Disable prompt login",

View File

@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock() c.engineMutex.Lock()
engine := c.engine
c.engine = nil c.engine = nil
c.engineMutex.Unlock() c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil { // todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil { if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
} }

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips) f.cache.set(domain, question.Qtype, ips)

View File

@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
return nil return nil
} }
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil { if e.connMgr != nil {
e.connMgr.Close() e.connMgr.Close()
@@ -292,21 +291,12 @@ func (e *Engine) Stop() error {
} }
log.Info("Network monitor: stopped") log.Info("Network monitor: stopped")
if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
log.Info("removing peers before dns")
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
if err := e.stopSSHServer(); err != nil { if err := e.stopSSHServer(); err != nil {
log.Warnf("failed to stop SSH server: %v", err) log.Warnf("failed to stop SSH server: %v", err)
} }
e.cleanupSSHConfig() e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil { if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil { if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err) log.Warnf("failed to cleanup forward rules: %v", err)
@@ -314,33 +304,28 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil e.ingressGatewayMgr = nil
} }
e.stopDNSForwarder() if e.srWatcher != nil {
e.srWatcher.Close()
}
if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { log.Info("cleaning up status recorder states")
log.Info("removing peers before routes") e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
if err := e.removeAllPeers(); err != nil { e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
return fmt.Errorf("failed to remove all peers: %s", err) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
}
if err := e.removeAllPeers(); err != nil {
log.Errorf("failed to remove all peers: %s", err)
} }
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.srWatcher != nil { e.stopDNSForwarder()
e.srWatcher.Close()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { // stop/restore DNS after peers are closed but before interface goes down
log.Info("removing peers after dns and routes") // so dbus and friends don't complain because of a missing interface
if err := e.removeAllPeers(); err != nil { e.stopDNSServer()
return fmt.Errorf("failed to remove all peers: %s", err)
}
}
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
@@ -353,16 +338,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close() e.flowManager.Close()
} }
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(stateCtx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) log.Errorf("failed to stop state manager: %v", err)
} }
if err := e.stateManager.PersistState(context.Background()); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout() timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -448,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil { if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err) return fmt.Errorf("create rosenpass manager: %w", err)
} }
err := e.rpManager.Run() if err := e.rpManager.Run(); err != nil {
if err != nil {
return fmt.Errorf("run rosenpass manager: %w", err) return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
@@ -501,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
} }
if err := e.createFirewall(); err != nil { if err := e.createFirewall(); err != nil {
e.close()
return err return err
} }
@@ -766,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.GetNetbirdConfig() != nil { if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig() wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns()) err := e.updateTURNs(wCfg.GetTurns())
@@ -805,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
} }
nm := update.GetNetworkMap() nm := update.GetNetworkMap()
if nm == nil { if nm == nil || update.SkipNetworkMapUpdate {
return nil return nil
} }
@@ -971,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth, e.config.DisableSSHAuth,
) )
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
@@ -1385,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key) conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)

View File

@@ -0,0 +1,79 @@
package internal
import (
"context"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}
resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}
// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}

View File

@@ -631,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client // feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse) updates := make(chan *mgmtProto.SyncResponse)
defer close(updates) defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error { syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates { for msg := range updates {
err := msgHandler(msg) err := msgHandler(msg)
if err != nil { if err != nil {

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"os"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error { func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client") log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName) log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
} }
return netIDs return netIDs
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
log.Debugf("Setting env variable %s: %s", k, v)
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
} else {
log.Debugf("Env variable %s was set successfully", k)
}
}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import _ "golang.org/x/mobile/bind" import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
import ( import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection // RoutesSelectionInfoCollection made for Java layer to get non default types as collection

View File

@@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil { if err := s.cleanupConnection(); err != nil {
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err) log.Errorf("failed to shut down properly: %v", err)
return nil, err return nil, err
} }
@@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
} }
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err) log.Errorf("failed to cleanup connection: %v", err)
return nil, err return nil, err
} }

View File

@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false return false
} }
// detectUtilLinuxLogin always returns false on JS/WASM
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM // executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM") logger.Errorf("PTY command execution not supported on JS/WASM")

View File

@@ -10,6 +10,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"runtime"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
return supported return supported
} }
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
// See https://bugs.debian.org/1078023 for details.
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
if runtime.GOOS != "linux" {
return false
}
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "login", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("login --version failed (likely shadow-utils): %v", err)
return false
}
isUtilLinux := strings.Contains(string(output), "util-linux")
log.Debugf("util-linux login detected: %v", isUtilLinux)
return isUtilLinux
}
// createSuCommand creates a command using su -l -c for privilege switching // createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su") suPath, err := exec.LookPath("su")
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false return false
} }
logger.Infof("starting interactive shell: %s", execCmd.Path) logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
} }

View File

@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false return false
} }
// detectUtilLinuxLogin always returns false on Windows
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand() command := session.RawCommand()

View File

@@ -138,7 +138,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig jwtConfig *JWTConfig
suSupportsPty bool suSupportsPty bool
loginIsUtilLinux bool
} }
type JWTConfig struct { type JWTConfig struct {
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
} }
s.suSupportsPty = s.detectSuPtySupport(ctx) s.suSupportsPty = s.detectSuPtySupport(ctx)
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
ln, addrDesc, err := s.createListener(ctx, addr) ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil { if err != nil {

View File

@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { return p, a, nil
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default: default:
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
} }
} }
// fileExists checks if a file exists (helper for login command logic) // getLinuxLoginCmd returns the login command for Linux systems.
// Handles differences between util-linux and shadow-utils login implementations.
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
// Special handling for Arch Linux without /etc/pam.d/remote
var loginArgs []string
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
loginArgs = []string{"-f", username, "-p"}
} else {
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
}
// util-linux login requires setsid -c to create a new session and set the
// controlling terminal. Without this, vhangup() kills the parent process.
// See https://bugs.debian.org/1078023 for details.
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
// to avoid external setsid dependency.
if !s.loginIsUtilLinux {
return loginPath, loginArgs
}
setsidPath, err := exec.LookPath("setsid")
if err != nil {
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
return loginPath, loginArgs
}
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
return setsidPath, args
}
// fileExists checks if a file exists
func (s *Server) fileExists(path string) bool { func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path) _, err := os.Stat(path)
return err == nil return err == nil

View File

@@ -120,6 +120,26 @@ func (i *Info) SetFlags(
} }
} }
func (i *Info) CopyFlagsFrom(other *Info) {
i.SetFlags(
other.RosenpassEnabled,
other.RosenpassPermissive,
&other.ServerSSHAllowed,
other.DisableClientRoutes,
other.DisableServerRoutes,
other.DisableDNS,
other.DisableFirewall,
other.BlockLANAccess,
other.BlockInbound,
other.LazyConnectionEnabled,
&other.EnableSSHRoot,
&other.EnableSSHSFTP,
&other.EnableSSHLocalPortForwarding,
&other.EnableSSHRemotePortForwarding,
&other.DisableSSHAuth,
)
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string { func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx) md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -8,6 +8,90 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
func TestInfo_CopyFlagsFrom(t *testing.T) {
origin := &Info{}
serverSSHAllowed := true
enableSSHRoot := true
enableSSHSFTP := false
enableSSHLocalPortForwarding := true
enableSSHRemotePortForwarding := false
disableSSHAuth := true
origin.SetFlags(
true, // RosenpassEnabled
false, // RosenpassPermissive
&serverSSHAllowed,
true, // DisableClientRoutes
false, // DisableServerRoutes
true, // DisableDNS
false, // DisableFirewall
true, // BlockLANAccess
false, // BlockInbound
true, // LazyConnectionEnabled
&enableSSHRoot,
&enableSSHSFTP,
&enableSSHLocalPortForwarding,
&enableSSHRemotePortForwarding,
&disableSSHAuth,
)
got := &Info{}
got.CopyFlagsFrom(origin)
if got.RosenpassEnabled != true {
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
}
if got.RosenpassPermissive != false {
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
}
if got.ServerSSHAllowed != true {
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
}
if got.DisableClientRoutes != true {
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
}
if got.DisableServerRoutes != false {
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
}
if got.DisableDNS != true {
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
}
if got.DisableFirewall != false {
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
}
if got.BlockLANAccess != true {
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
}
if got.BlockInbound != false {
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
}
if got.LazyConnectionEnabled != true {
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
}
if got.EnableSSHRoot != true {
t.Fatalf("EnableSSHRoot not copied: got %v", got.EnableSSHRoot)
}
if got.EnableSSHSFTP != false {
t.Fatalf("EnableSSHSFTP not copied: got %v", got.EnableSSHSFTP)
}
if got.EnableSSHLocalPortForwarding != true {
t.Fatalf("EnableSSHLocalPortForwarding not copied: got %v", got.EnableSSHLocalPortForwarding)
}
if got.EnableSSHRemotePortForwarding != false {
t.Fatalf("EnableSSHRemotePortForwarding not copied: got %v", got.EnableSSHRemotePortForwarding)
}
if got.DisableSSHAuth != true {
t.Fatalf("DisableSSHAuth not copied: got %v", got.DisableSSHAuth)
}
// ensure CopyFlagsFrom does not touch unrelated fields
origin.Hostname = "host-a"
got.Hostname = "host-b"
got.CopyFlagsFrom(origin)
if got.Hostname != "host-b" {
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
}
}
func Test_LocalWTVersion(t *testing.T) { func Test_LocalWTVersion(t *testing.T) {
got := GetInfo(context.TODO()) got := GetInfo(context.TODO())
want := "development" want := "development"

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -394,23 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil return nil
} }
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval { network, err := c.repo.GetAccountNetwork(ctx, accountID)
network, err := c.repo.GetAccountNetwork(ctx, accountID) if err != nil {
if err != nil { return nil, nil, nil, 0, err
return nil, nil, nil, 0, err }
}
if isRequiresApproval {
emptyMap := &types.NetworkMap{ emptyMap := &types.NetworkMap{
Network: network.Copy(), Network: network.Copy(),
} }
return peer, emptyMap, nil, 0, nil return peer, emptyMap, nil, 0, nil
} }
var ( if clientSerial > 0 && clientSerial == network.CurrentSerial() {
account *types.Account log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial)
err error return peer, nil, nil, 0, nil
) }
var account *types.Account
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID) account = c.getAccountFromHolderOrInit(accountID)
} else { } else {

View File

@@ -24,7 +24,7 @@ type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string) error UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string) error BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context) StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
} }
// GetValidatedPeerWithMap mocks base method. // GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial)
ret0, _ := ret[0].(*peer.Peer) ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap) ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks) ret2, _ := ret[2].([]*posture.Checks)
@@ -125,9 +125,9 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires
} }
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. // GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial)
} }
// OnPeerConnected mocks base method. // OnPeerConnected mocks base method.

View File

@@ -104,6 +104,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
} }
} }
// ToSkipSyncResponse creates a minimal SyncResponse when the client already has the latest network map.
func ToSkipSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, checks []*posture.Checks, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{
SkipNetworkMapUpdate: true,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
return response
}
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),

View File

@@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
} }
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation // todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil { if s.appMetrics != nil {
@@ -194,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
} }
if s.logBlockedPeers { if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
} }
if s.blockPeersWithSameConfig { if s.blockPeersWithSameConfig {
s.syncSem.Add(-1) s.syncSem.Add(-1)
@@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err return err
} }
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
} }
}() }()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -246,7 +239,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String()) metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash) s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncReq.GetNetworkMapSerial())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
@@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID) s.secretsManager.CancelRefresh(peer.ID)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
} }
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -525,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
reqStart := time.Now() reqStart := time.Now()
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
sRealIP := realIP.String() sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{} loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq) peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -537,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
metahashed := metaHash(peerMeta, sRealIP) metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers { if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
} }
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -561,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
//nolint //nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() { defer func() {
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
} }
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}() }()
if loginReq.GetMeta() == nil { if loginReq.GetMeta() == nil {
@@ -604,16 +592,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err) return nil, mapError(ctx, err)
} }
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer") return nil, status.Errorf(codes.Internal, "failed logging in peer")
} }
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
key, err := s.secretsManager.GetWGKey() key, err := s.secretsManager.GetWGKey()
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
@@ -718,7 +702,12 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err) return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
} }
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) var plainResp *proto.SyncResponse
if networkMap == nil {
plainResp = ToSkipSyncResponse(ctx, s.config, peer, turnToken, relayToken, postureChecks, settings.Extra, peerGroups)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
key, err := s.secretsManager.GetWGKey() key, err := s.secretsManager.GetWGKey()
if err != nil { if err != nil {
@@ -730,12 +719,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{ err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(), WgPubKey: key.PublicKey().String(),
Body: encryptedResp, Body: encryptedResp,
}) })
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -750,10 +737,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
@@ -813,10 +796,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {

View File

@@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1) relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel) go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
} }
} }
@@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for { for {
select { select {
case <-cancel: case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return return
case <-ticker.C: case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID) m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for { for {
select { select {
case <-cancel: case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return return
case <-ticker.C: case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID) m.pushNewRelayTokens(ctx, accountID, peerID)

View File

@@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err return err
} }
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err return err
} }
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to approve pending peers: %w", err)
}
if approvedCount > 0 {
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
updateAccountPeers = true
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange { if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err return err
@@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil return newSettings, nil
} }
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
} }
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -787,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
if ctx == nil {
ctx = context.Background()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -1607,8 +1617,8 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
} }
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, NetworkMapSerial: clientSerial}, accountID)
if err != nil { if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
} }

View File

@@ -107,7 +107,7 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)

View File

@@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
accountID := account.Id
userID := account.Users[account.CreatedBy].Id
ctx := context.Background()
newSettings := account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: true,
}
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
peer1.Status.RequiresApproval = true
peer2.Status.RequiresApproval = true
peer3.Status.RequiresApproval = false
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
newSettings = account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: false,
}
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range accountPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
}
}
func TestAccount_GetExpiredPeers(t *testing.T) { func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct { type test struct {
name string name string
@@ -3107,7 +3144,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, 0)
assert.NoError(b, err) assert.NoError(b, err)
} }

View File

@@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
} }
if userAuth.AccountId != accountId { if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId userAuth.AccountId = accountId
} }

View File

@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil return nil
} }

View File

@@ -10,7 +10,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies // IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface { type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)

View File

@@ -37,7 +37,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -177,9 +177,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil { if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, clientSerial)
} }
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }

View File

@@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
} }
} }
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil { if err != nil {
@@ -645,7 +645,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
} }
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer, 0)
return p, nmap, pc, err return p, nmap, pc, err
} }
@@ -731,7 +731,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
} }
} }
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer, sync.NetworkMapSerial)
} }
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err return nil, nil, nil, err
} }
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil { if err != nil {
@@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err return nil, nil, nil, err
} }
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil { if err != nil {
@@ -862,7 +859,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
} }
} }
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer, 0)
return p, nmap, pc, err return p, nmap, pc, err
} }

View File

@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil { if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil return false, nil
} }
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil { if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil return false, nil
} }

View File

@@ -5,9 +5,6 @@ package settings
import ( import (
"context" "context"
"fmt" "fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings" "github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
} }
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator { if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil { if err != nil {

View File

@@ -27,7 +27,6 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took) s.metrics.StoreMetrics().CountPersistenceDuration(took)
} }
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err return err
} }
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil return nil
} }
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
Update("peer_status_requires_approval", false)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
}
return int(result.RowsAffected), nil
}
// SaveUsers saves the given list of users to the database. // SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 { if len(users) == 0 {
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var user types.User var user types.User
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
@@ -2152,16 +2160,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
@@ -2171,16 +2176,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -2229,11 +2231,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User var user types.User
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
@@ -2491,16 +2490,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx.WithContext(ctx). result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key) Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil { if result.Error != nil {
@@ -2514,10 +2510,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
ctx, cancel := getDebuggingCtx(ctx) result := s.db.Model(&types.SetupKey{}).
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID). Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{ Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"), "used_times": gorm.Expr("used_times + 1"),
@@ -2537,11 +2530,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string var groupID string
_ = s.db.WithContext(ctx).Model(types.Group{}). _ = s.db.Model(types.Group{}).
Select("id"). Select("id").
Where("account_id = ? AND name = ?", accountID, "All"). Where("account_id = ? AND name = ?", accountID, "All").
Limit(1). Limit(1).
@@ -2569,9 +2559,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group // AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{ peer := &types.GroupPeer{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
@@ -2768,10 +2755,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
ctx, cancel := getDebuggingCtx(ctx) if err := s.db.Create(peer).Error; err != nil {
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
@@ -2897,10 +2881,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
ctx, cancel := getDebuggingCtx(ctx) result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store") return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -4022,36 +4003,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil return groupPeers, nil
} }
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}). result := s.db.Model(&types.Account{}).
@@ -4091,7 +4042,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet}, Network: &types.Network{Net: ipNet},
} }
result := s.db.WithContext(ctx). result := s.db.
Model(&types.Account{}). Model(&types.Account{}).
Where(idQueryCondition, accountID). Where(idQueryCondition, accountID).
Updates(&patch) Updates(&patch)

View File

@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
}) })
} }
} }
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
accountID := "test-account"
ctx := context.Background()
account := newAccountWithId(ctx, accountID, "testuser", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
peers := []*nbpeer.Peer{
{
ID: "peer1",
AccountID: accountID,
DNSLabel: "peer1.netbird.cloud",
Key: "peer1-key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer2",
AccountID: accountID,
DNSLabel: "peer2.netbird.cloud",
Key: "peer2-key",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer3",
AccountID: accountID,
DNSLabel: "peer3.netbird.cloud",
Key: "peer3-key",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{
RequiresApproval: false,
LastSeen: time.Now().UTC(),
},
},
}
for _, peer := range peers {
err = store.AddPeerToAccount(ctx, peer)
require.NoError(t, err)
}
t.Run("approve all pending peers", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 2, count)
allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range allPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
}
})
t.Run("no peers to approve", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 0, count)
})
t.Run("non-existent account", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, "non-existent")
require.NoError(t, err)
assert.Equal(t, 0, count)
})
})
}

View File

@@ -143,6 +143,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)

View File

@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter meter metric.Meter
syncRequestsCounter metric.Int64Counter syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"), metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter, meter: meter,
syncRequestsCounter: syncRequestsCounter, syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter, loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests // CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
} }
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx)) h.ServeHTTP(w, r.WithContext(ctx))
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 { if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else { } else {

View File

@@ -15,6 +15,9 @@ type PeerSync struct {
// UpdateAccountPeers indicate updating account peers, // UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated // which occurs when the peer's metadata is updated
UpdateAccountPeers bool UpdateAccountPeers bool
// NetworkMapSerial is the last known network map serial number on the client.
// Used to skip network map recalculation if client already has the latest.
NetworkMapSerial uint64
} }
// PeerLogin used as a data object between the gRPC API and Manager on Login request. // PeerLogin used as a data object between the gRPC API and Manager on Login request.

View File

@@ -0,0 +1,31 @@
package peerid
import (
"crypto/sha256"
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
"github.com/netbirdio/netbird/shared/relay/messages"
)
var (
// HealthCheckPeerID is the hashed peer ID for health check connections
HealthCheckPeerID = messages.HashID("healthcheck-agent")
// DummyAuthToken is a structurally valid auth token for health check.
// The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
DummyAuthToken = createDummyToken()
)
func createDummyToken() []byte {
token := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: make([]byte, sha256.Size),
Payload: []byte("healthcheck"),
}
return token.Marshal()
}
// IsHealthCheck checks if the given peer ID is the health check agent
func IsHealthCheck(peerID *messages.PeerID) bool {
return peerID != nil && *peerID == HealthCheckPeerID
}

View File

@@ -7,8 +7,10 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/shared/relay/messages"
) )
func dialWS(ctx context.Context, address url.URL) error { func dialWS(ctx context.Context, address url.URL) error {
@@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err) return fmt.Errorf("failed to connect to websocket: %w", err)
} }
defer func() {
_ = conn.CloseNow()
}()
authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
if err != nil {
return fmt.Errorf("failed to marshal auth message: %w", err)
}
if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
return fmt.Errorf("failed to write auth message: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil return nil
} }

View File

@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
if err != nil { if err != nil {
return nil, err return peerID, err
} }
h.peerID = peerID h.peerID = peerID
return peerID, nil return peerID, nil
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
} }
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, nil return rawPeerID, nil

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
@@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
} }
peerID, err := h.handshakeReceive() peerID, err := h.handshakeReceive()
if err != nil { if err != nil {
log.Errorf("failed to handshake: %s", err) if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
} else {
log.Errorf("failed to handshake: %s", err)
}
if cErr := conn.Close(); cErr != nil { if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
} }

View File

@@ -13,7 +13,7 @@ import (
type Client interface { type Client interface {
io.Closer io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error) GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)

View File

@@ -313,7 +313,7 @@ func TestClient_Sync(t *testing.T) {
defer cancel() defer cancel()
go func() { go func() {
err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error { err = client.Sync(ctx, info, 0, func(msg *mgmtProto.SyncResponse) error {
ch <- msg ch <- msg
return nil return nil
}) })

View File

@@ -6,14 +6,14 @@ package common
// //
// | Value | Flag | OAuth Parameters | // | Value | Flag | OAuth Parameters |
// |-------|----------------------|-----------------------------------------| // |-------|----------------------|-----------------------------------------|
// | 0 | LoginFlagPromptLogin | prompt=select_account login | // | 0 | LoginFlagPromptLogin | prompt=login |
// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | // | 1 | LoginFlagMaxAge0 | max_age=0 |
type LoginFlag uint8 type LoginFlag uint8
const ( const (
// LoginFlagPromptLogin adds prompt=select_account login to the authorization request // LoginFlagPromptLogin adds prompt=login to the authorization request
LoginFlagPromptLogin LoginFlag = iota LoginFlagPromptLogin LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request // LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0 LoginFlagMaxAge0
// LoginFlagNone disables all login flags // LoginFlagNone disables all login flags
LoginFlagNone LoginFlagNone

View File

@@ -110,7 +110,7 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function // Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error { operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState()) log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState() connState := c.conn.GetState()
@@ -128,7 +128,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err return err
} }
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) return c.handleStream(ctx, *serverPubKey, sysInfo, networkSerial, msgHandler)
} }
err := backoff.Retry(operation, defaultBackoff(ctx)) err := backoff.Retry(operation, defaultBackoff(ctx))
@@ -140,11 +140,11 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
} }
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error { networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx) ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream() defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo) stream, err := c.connectToStream(ctx, serverPubKey, sysInfo, networkSerial)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -186,7 +186,8 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
ctx, cancelStream := context.WithCancel(c.ctx) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream() defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo) // GetNetworkMap doesn't have a serial to send, so we pass 0
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo, 0)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
return nil, err return nil, err
@@ -219,8 +220,17 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
return decryptedResp.GetNetworkMap(), nil return decryptedResp.GetNetworkMap(), nil
} }
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, networkSerial uint64) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)} // Always compute latest system info to ensure up-to-date PeerSystemMeta on first and subsequent syncs
recomputed := system.GetInfo(c.ctx)
if sysInfo != nil {
recomputed.CopyFlagsFrom(sysInfo)
// carry over posture files if any were computed
if len(sysInfo.Files) > 0 {
recomputed.Files = sysInfo.Files
}
}
req := &proto.SyncRequest{Meta: infoToMetaData(recomputed), NetworkMapSerial: networkSerial}
myPrivateKey := c.key myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey() myPublicKey := myPrivateKey.PublicKey()

View File

@@ -12,7 +12,7 @@ import (
type MockClient struct { type MockClient struct {
CloseFunc func() error CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error SyncFunc func(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error) GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
@@ -33,11 +33,11 @@ func (m *MockClient) Close() error {
return m.CloseFunc() return m.CloseFunc()
} }
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil { if m.SyncFunc == nil {
return nil return nil
} }
return m.SyncFunc(ctx, sysInfo, msgHandler) return m.SyncFunc(ctx, sysInfo, networkSerial, msgHandler)
} }
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {

View File

@@ -7,12 +7,13 @@
package proto package proto
import ( import (
reflect "reflect"
sync "sync"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
durationpb "google.golang.org/protobuf/types/known/durationpb" durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb" timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
) )
const ( const (
@@ -343,6 +344,8 @@ type SyncRequest struct {
// Meta data of the peer // Meta data of the peer
Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"` Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"`
// Optional: last known NetworkMap serial number on the client
NetworkMapSerial uint64 `protobuf:"varint,2,opt,name=networkMapSerial,proto3" json:"networkMapSerial,omitempty"`
} }
func (x *SyncRequest) Reset() { func (x *SyncRequest) Reset() {
@@ -384,6 +387,13 @@ func (x *SyncRequest) GetMeta() *PeerSystemMeta {
return nil return nil
} }
func (x *SyncRequest) GetNetworkMapSerial() uint64 {
if x != nil {
return x.NetworkMapSerial
}
return 0
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs) // SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
type SyncResponse struct { type SyncResponse struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -402,6 +412,8 @@ type SyncResponse struct {
NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"` NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"`
// Posture checks to be evaluated by client // Posture checks to be evaluated by client
Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"` Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"`
// Indicates whether the client should skip updating the network map
SkipNetworkMapUpdate bool `protobuf:"varint,7,opt,name=skipNetworkMapUpdate,proto3" json:"skipNetworkMapUpdate,omitempty"`
} }
func (x *SyncResponse) Reset() { func (x *SyncResponse) Reset() {
@@ -478,6 +490,13 @@ func (x *SyncResponse) GetChecks() []*Checks {
return nil return nil
} }
func (x *SyncResponse) GetSkipNetworkMapUpdate() bool {
if x != nil {
return x.SkipNetworkMapUpdate
}
return false
}
type SyncMetaRequest struct { type SyncMetaRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@@ -3518,33 +3537,39 @@ var file_management_proto_rawDesc = []byte{
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12,
0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x6f, 0x64, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03,
0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3d, 0x0a, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x69, 0x0a,
0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04,
0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74,
0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x22, 0xdb, 0x02, 0x0a, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x12, 0x2a, 0x0a, 0x10,
0x0c, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c,
0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x61, 0x70, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x8f, 0x03, 0x0a, 0x0c, 0x53, 0x79, 0x6e,
0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74,
0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b,
0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65,
0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74,
0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72,
0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43,
0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65,
0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x36, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72,
0x6b, 0x4d, 0x61, 0x70, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70,
0x61, 0x70, 0x52, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x2a, 0x74, 0x79, 0x12, 0x36, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x52, 0x0a,
0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x41, 0x0a, 0x0f, 0x53, 0x79, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68,
0x65, 0x63, 0x6b, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e,
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06,
0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x73, 0x6b, 0x69, 0x70, 0x4e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x07,
0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x73, 0x6b, 0x69, 0x70, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x6b, 0x4d, 0x61, 0x70, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x41, 0x0a, 0x0f, 0x53, 0x79,
0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a,
0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73,

View File

@@ -63,6 +63,8 @@ message EncryptedMessage {
message SyncRequest { message SyncRequest {
// Meta data of the peer // Meta data of the peer
PeerSystemMeta meta = 1; PeerSystemMeta meta = 1;
// Optional: last known NetworkMap serial number on the client
uint64 networkMapSerial = 2;
} }
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs) // SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
@@ -85,6 +87,9 @@ message SyncResponse {
// Posture checks to be evaluated by client // Posture checks to be evaluated by client
repeated Checks Checks = 6; repeated Checks Checks = 6;
// Indicates whether the client should skip updating the network map
bool skipNetworkMapUpdate = 7;
} }
message SyncMetaRequest { message SyncMetaRequest {