mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
Compare commits
22 Commits
v0.60.5
...
sync-clien
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72513d7522 | ||
|
|
a1f1bf1f19 | ||
|
|
b5dec3df39 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
20f5f00635 | ||
|
|
fc141cf3a3 | ||
|
|
d0c65fa08e | ||
|
|
f241bfa339 | ||
|
|
4b2cd97d5f |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Running pre-push hook..."
|
||||||
|
if ! make lint; then
|
||||||
|
echo ""
|
||||||
|
echo "Hint: To push without verification, run:"
|
||||||
|
echo " git push --no-verify"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All checks passed!"
|
||||||
@@ -136,6 +136,14 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
6. Configure Git hooks for automatic linting:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make setup-hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||||
|
|
||||||
### Dev Container Support
|
### Dev Container Support
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||||
|
|||||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: lint lint-all lint-install setup-hooks
|
||||||
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
|
# Install golangci-lint locally if needed
|
||||||
|
$(GOLANGCI_LINT):
|
||||||
|
@echo "Installing golangci-lint..."
|
||||||
|
@mkdir -p ./bin
|
||||||
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
# Lint only changed files (fast, for pre-push)
|
||||||
|
lint: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on changed files..."
|
||||||
|
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||||
|
|
||||||
|
# Lint entire codebase (slow, matches CI)
|
||||||
|
lint-all: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on all files..."
|
||||||
|
@$(GOLANGCI_LINT) run --timeout=12m
|
||||||
|
|
||||||
|
# Just install the linter
|
||||||
|
lint-install: $(GOLANGCI_LINT)
|
||||||
|
|
||||||
|
# Setup git hooks for all developers
|
||||||
|
setup-hooks:
|
||||||
|
@git config core.hooksPath .githooks
|
||||||
|
@chmod +x .githooks/pre-push
|
||||||
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
@@ -27,7 +27,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
@@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
var err error
|
var err error
|
||||||
r.filterTable, err = r.loadFilterTable()
|
r.filterTable, err = r.loadFilterTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to load filter table, skipping accept rules: %v", err)
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
@@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
return nil, errFilterTableNotFound
|
return nil, errFilterTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
@@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||||
func (r *router) acceptForwardRules() error {
|
func (r *router) acceptForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
|||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// filter table exists but iptables is not
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables()
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
@@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
}
|
}
|
||||||
@@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
}
|
}
|
||||||
@@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
|||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesNftables() error {
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
forwardChain := &nftables.Chain{
|
||||||
|
Name: chainNameForward,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *router) acceptExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
iifRule := &nftables.Rule{
|
iifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
inputRule := &nftables.Rule{
|
inputRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameInput,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(inputRule)
|
r.conn.InsertRule(inputRule)
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptFilterRulesNftables()
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list chains: %v", err)
|
return fmt.Errorf("list chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
if chain.Table.Name != table.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *router) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get rules: %v", err)
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, chain := range allChains {
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
if r.isExternalChain(chain) {
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
chains = append(chains, chain)
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
return chains
|
||||||
return fmt.Errorf(flushError, err)
|
}
|
||||||
|
|
||||||
|
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Skip all iptables-managed tables in the ip family
|
||||||
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
@@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
|||||||
@@ -107,10 +107,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
switch p.providerConfig.LoginFlag {
|
switch p.providerConfig.LoginFlag {
|
||||||
case common.LoginFlagPromptLogin:
|
case common.LoginFlagPromptLogin:
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
case common.LoginFlagMaxAge0:
|
case common.LoginFlagMaxAge0:
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "select_account"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if p.providerConfig.LoginHint != "" {
|
if p.providerConfig.LoginHint != "" {
|
||||||
|
|||||||
@@ -15,9 +15,8 @@ import (
|
|||||||
|
|
||||||
func TestPromptLogin(t *testing.T) {
|
func TestPromptLogin(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
promptSelectAccountLogin = "prompt=login+select_account"
|
promptLogin = "prompt=login"
|
||||||
promptSelectAccount = "prompt=select_account"
|
maxAge0 = "max_age=0"
|
||||||
maxAge0 = "max_age=0"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
@@ -27,14 +26,14 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
expectContains []string
|
expectContains []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Prompt login with select account",
|
name: "Prompt login",
|
||||||
loginFlag: mgm.LoginFlagPromptLogin,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
expectContains: []string{promptSelectAccountLogin},
|
expectContains: []string{promptLogin},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Max age 0 with select account",
|
name: "Max age 0",
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
expectContains: []string{maxAge0, promptSelectAccount},
|
expectContains: []string{maxAge0},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disable prompt login",
|
name: "Disable prompt login",
|
||||||
|
|||||||
@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := c.engine
|
|
||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if engine != nil && engine.wgInterface != nil {
|
// todo: consider to remove this condition. Is not thread safe.
|
||||||
|
// We should always call Stop(), but we need to verify that it is idempotent
|
||||||
|
if engine.wgInterface != nil {
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|||||||
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
@@ -292,21 +291,12 @@ func (e *Engine) Stop() error {
|
|||||||
}
|
}
|
||||||
log.Info("Network monitor: stopped")
|
log.Info("Network monitor: stopped")
|
||||||
|
|
||||||
if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
|
|
||||||
log.Info("removing peers before dns")
|
|
||||||
if err := e.removeAllPeers(); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := e.stopSSHServer(); err != nil {
|
if err := e.stopSSHServer(); err != nil {
|
||||||
log.Warnf("failed to stop SSH server: %v", err)
|
log.Warnf("failed to stop SSH server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.cleanupSSHConfig()
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
|
||||||
e.stopDNSServer()
|
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
log.Warnf("failed to cleanup forward rules: %v", err)
|
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||||
@@ -314,33 +304,28 @@ func (e *Engine) Stop() error {
|
|||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stopDNSForwarder()
|
if e.srWatcher != nil {
|
||||||
|
e.srWatcher.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" {
|
log.Info("cleaning up status recorder states")
|
||||||
log.Info("removing peers before routes")
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
if err := e.removeAllPeers(); err != nil {
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
}
|
|
||||||
|
if err := e.removeAllPeers(); err != nil {
|
||||||
|
log.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
e.stopDNSForwarder()
|
||||||
e.srWatcher.Close()
|
|
||||||
}
|
|
||||||
log.Info("cleaning up status recorder states")
|
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
|
||||||
|
|
||||||
if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" {
|
// stop/restore DNS after peers are closed but before interface goes down
|
||||||
log.Info("removing peers after dns and routes")
|
// so dbus and friends don't complain because of a missing interface
|
||||||
if err := e.removeAllPeers(); err != nil {
|
e.stopDNSServer()
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
@@ -353,16 +338,18 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer stateCancel()
|
||||||
|
|
||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
log.Errorf("failed to stop state manager: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
timeout := e.calculateShutdownTimeout()
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
@@ -448,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
if err := e.rpManager.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -501,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -766,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
@@ -805,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil || update.SkipNetworkMapUpdate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -971,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1385,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
|
|||||||
79
client/internal/engine_sync_test.go
Normal file
79
client/internal/engine_sync_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
|
||||||
|
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun199",
|
||||||
|
WgAddr: "100.70.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
// Precondition
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when NetworkMap is nil
|
||||||
|
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun198",
|
||||||
|
WgAddr: "100.70.0.2/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33101,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -631,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(fd int32, interfaceName string) error {
|
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
|
|||||||
}
|
}
|
||||||
return netIDs
|
return netIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exportEnvList(list *EnvList) {
|
||||||
|
if list == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range list.AllItems() {
|
||||||
|
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
|
||||||
|
log.Debugf("Setting env variable %s: %s", k, v)
|
||||||
|
|
||||||
|
if err := os.Setenv(k, v); err != nil {
|
||||||
|
log.Errorf("could not set env variable %s: %v", k, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Env variable %s was set successfully", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
client/ios/NetBirdSDK/env_list.go
Normal file
34
client/ios/NetBirdSDK/env_list.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package NetBirdSDK
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// EnvList is an exported struct to be bound by gomobile
|
||||||
|
type EnvList struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvList creates a new EnvList
|
||||||
|
func NewEnvList() *EnvList {
|
||||||
|
return &EnvList{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair
|
||||||
|
func (el *EnvList) Put(key, value string) {
|
||||||
|
el.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key
|
||||||
|
func (el *EnvList) Get(key string) string {
|
||||||
|
return el.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *EnvList) AllItems() map[string]string {
|
||||||
|
return el.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBForceRelay() string {
|
||||||
|
return peer.EnvKeyNBForceRelay
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
import _ "golang.org/x/mobile/bind"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
||||||
|
|||||||
@@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
|||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
if err := s.cleanupConnection(); err != nil {
|
if err := s.cleanupConnection(); err != nil {
|
||||||
|
// todo review to update the status in case any type of error
|
||||||
log.Errorf("failed to shut down properly: %v", err)
|
log.Errorf("failed to shut down properly: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||||
|
// todo review to update the status in case any type of error
|
||||||
log.Errorf("failed to cleanup connection: %v", err)
|
log.Errorf("failed to cleanup connection: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin always returns false on JS/WASM
|
||||||
|
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// executeCommandWithPty is not supported on JS/WASM
|
// executeCommandWithPty is not supported on JS/WASM
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
|||||||
return supported
|
return supported
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
|
||||||
|
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
|
||||||
|
// See https://bugs.debian.org/1078023 for details.
|
||||||
|
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "login", "--version")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("login --version failed (likely shadow-utils): %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
isUtilLinux := strings.Contains(string(output), "util-linux")
|
||||||
|
log.Debugf("util-linux login detected: %v", isUtilLinux)
|
||||||
|
return isUtilLinux
|
||||||
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
// createSuCommand creates a command using su -l -c for privilege switching
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
suPath, err := exec.LookPath("su")
|
suPath, err := exec.LookPath("su")
|
||||||
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
|
||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin always returns false on Windows
|
||||||
|
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
command := session.RawCommand()
|
command := session.RawCommand()
|
||||||
|
|||||||
@@ -138,7 +138,8 @@ type Server struct {
|
|||||||
jwtExtractor *jwt.ClaimsExtractor
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
jwtConfig *JWTConfig
|
jwtConfig *JWTConfig
|
||||||
|
|
||||||
suSupportsPty bool
|
suSupportsPty bool
|
||||||
|
loginIsUtilLinux bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
||||||
|
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
|
||||||
|
|
||||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
|||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "linux":
|
case "linux":
|
||||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
|
||||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
return p, a, nil
|
||||||
return loginPath, []string{"-f", username, "-p"}, nil
|
|
||||||
}
|
|
||||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
|
||||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||||
default:
|
default:
|
||||||
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fileExists checks if a file exists (helper for login command logic)
|
// getLinuxLoginCmd returns the login command for Linux systems.
|
||||||
|
// Handles differences between util-linux and shadow-utils login implementations.
|
||||||
|
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
|
||||||
|
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||||
|
var loginArgs []string
|
||||||
|
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||||
|
loginArgs = []string{"-f", username, "-p"}
|
||||||
|
} else {
|
||||||
|
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// util-linux login requires setsid -c to create a new session and set the
|
||||||
|
// controlling terminal. Without this, vhangup() kills the parent process.
|
||||||
|
// See https://bugs.debian.org/1078023 for details.
|
||||||
|
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
|
||||||
|
// to avoid external setsid dependency.
|
||||||
|
if !s.loginIsUtilLinux {
|
||||||
|
return loginPath, loginArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
setsidPath, err := exec.LookPath("setsid")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
|
||||||
|
return loginPath, loginArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
|
||||||
|
return setsidPath, args
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileExists checks if a file exists
|
||||||
func (s *Server) fileExists(path string) bool {
|
func (s *Server) fileExists(path string) bool {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return err == nil
|
return err == nil
|
||||||
|
|||||||
@@ -120,6 +120,26 @@ func (i *Info) SetFlags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Info) CopyFlagsFrom(other *Info) {
|
||||||
|
i.SetFlags(
|
||||||
|
other.RosenpassEnabled,
|
||||||
|
other.RosenpassPermissive,
|
||||||
|
&other.ServerSSHAllowed,
|
||||||
|
other.DisableClientRoutes,
|
||||||
|
other.DisableServerRoutes,
|
||||||
|
other.DisableDNS,
|
||||||
|
other.DisableFirewall,
|
||||||
|
other.BlockLANAccess,
|
||||||
|
other.BlockInbound,
|
||||||
|
other.LazyConnectionEnabled,
|
||||||
|
&other.EnableSSHRoot,
|
||||||
|
&other.EnableSSHSFTP,
|
||||||
|
&other.EnableSSHLocalPortForwarding,
|
||||||
|
&other.EnableSSHRemotePortForwarding,
|
||||||
|
&other.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
|
|||||||
@@ -8,6 +8,90 @@ import (
|
|||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestInfo_CopyFlagsFrom(t *testing.T) {
|
||||||
|
origin := &Info{}
|
||||||
|
serverSSHAllowed := true
|
||||||
|
enableSSHRoot := true
|
||||||
|
enableSSHSFTP := false
|
||||||
|
enableSSHLocalPortForwarding := true
|
||||||
|
enableSSHRemotePortForwarding := false
|
||||||
|
disableSSHAuth := true
|
||||||
|
origin.SetFlags(
|
||||||
|
true, // RosenpassEnabled
|
||||||
|
false, // RosenpassPermissive
|
||||||
|
&serverSSHAllowed,
|
||||||
|
true, // DisableClientRoutes
|
||||||
|
false, // DisableServerRoutes
|
||||||
|
true, // DisableDNS
|
||||||
|
false, // DisableFirewall
|
||||||
|
true, // BlockLANAccess
|
||||||
|
false, // BlockInbound
|
||||||
|
true, // LazyConnectionEnabled
|
||||||
|
&enableSSHRoot,
|
||||||
|
&enableSSHSFTP,
|
||||||
|
&enableSSHLocalPortForwarding,
|
||||||
|
&enableSSHRemotePortForwarding,
|
||||||
|
&disableSSHAuth,
|
||||||
|
)
|
||||||
|
|
||||||
|
got := &Info{}
|
||||||
|
got.CopyFlagsFrom(origin)
|
||||||
|
|
||||||
|
if got.RosenpassEnabled != true {
|
||||||
|
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
|
||||||
|
}
|
||||||
|
if got.RosenpassPermissive != false {
|
||||||
|
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
|
||||||
|
}
|
||||||
|
if got.ServerSSHAllowed != true {
|
||||||
|
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
|
||||||
|
}
|
||||||
|
if got.DisableClientRoutes != true {
|
||||||
|
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
|
||||||
|
}
|
||||||
|
if got.DisableServerRoutes != false {
|
||||||
|
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
|
||||||
|
}
|
||||||
|
if got.DisableDNS != true {
|
||||||
|
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
|
||||||
|
}
|
||||||
|
if got.DisableFirewall != false {
|
||||||
|
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
|
||||||
|
}
|
||||||
|
if got.BlockLANAccess != true {
|
||||||
|
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
|
||||||
|
}
|
||||||
|
if got.BlockInbound != false {
|
||||||
|
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
|
||||||
|
}
|
||||||
|
if got.LazyConnectionEnabled != true {
|
||||||
|
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
|
||||||
|
}
|
||||||
|
if got.EnableSSHRoot != true {
|
||||||
|
t.Fatalf("EnableSSHRoot not copied: got %v", got.EnableSSHRoot)
|
||||||
|
}
|
||||||
|
if got.EnableSSHSFTP != false {
|
||||||
|
t.Fatalf("EnableSSHSFTP not copied: got %v", got.EnableSSHSFTP)
|
||||||
|
}
|
||||||
|
if got.EnableSSHLocalPortForwarding != true {
|
||||||
|
t.Fatalf("EnableSSHLocalPortForwarding not copied: got %v", got.EnableSSHLocalPortForwarding)
|
||||||
|
}
|
||||||
|
if got.EnableSSHRemotePortForwarding != false {
|
||||||
|
t.Fatalf("EnableSSHRemotePortForwarding not copied: got %v", got.EnableSSHRemotePortForwarding)
|
||||||
|
}
|
||||||
|
if got.DisableSSHAuth != true {
|
||||||
|
t.Fatalf("DisableSSHAuth not copied: got %v", got.DisableSSHAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure CopyFlagsFrom does not touch unrelated fields
|
||||||
|
origin.Hostname = "host-a"
|
||||||
|
got.Hostname = "host-b"
|
||||||
|
got.CopyFlagsFrom(origin)
|
||||||
|
if got.Hostname != "host-b" {
|
||||||
|
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_LocalWTVersion(t *testing.T) {
|
func Test_LocalWTVersion(t *testing.T) {
|
||||||
got := GetInfo(context.TODO())
|
got := GetInfo(context.TODO())
|
||||||
want := "development"
|
want := "development"
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -64,7 +64,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -394,23 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if isRequiresApproval {
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, nil, nil, 0, err
|
||||||
return nil, nil, nil, 0, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
if isRequiresApproval {
|
||||||
emptyMap := &types.NetworkMap{
|
emptyMap := &types.NetworkMap{
|
||||||
Network: network.Copy(),
|
Network: network.Copy(),
|
||||||
}
|
}
|
||||||
return peer, emptyMap, nil, 0, nil
|
return peer, emptyMap, nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
if clientSerial > 0 && clientSerial == network.CurrentSerial() {
|
||||||
account *types.Account
|
log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial)
|
||||||
err error
|
return peer, nil, nil, 0, nil
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var account *types.Account
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
account = c.getAccountFromHolderOrInit(accountID)
|
account = c.getAccountFromHolderOrInit(accountID)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type Controller interface {
|
|||||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
GetDNSDomain(settings *types.Settings) string
|
GetDNSDomain(settings *types.Settings) string
|
||||||
StartWarmup(context.Context)
|
StartWarmup(context.Context)
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
|
|||||||
@@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetValidatedPeerWithMap mocks base method.
|
// GetValidatedPeerWithMap mocks base method.
|
||||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial)
|
||||||
ret0, _ := ret[0].(*peer.Peer)
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
ret1, _ := ret[1].(*types.NetworkMap)
|
ret1, _ := ret[1].(*types.NetworkMap)
|
||||||
ret2, _ := ret[2].([]*posture.Checks)
|
ret2, _ := ret[2].([]*posture.Checks)
|
||||||
@@ -125,9 +125,9 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerConnected mocks base method.
|
// OnPeerConnected mocks base method.
|
||||||
|
|||||||
@@ -104,6 +104,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToSkipSyncResponse creates a minimal SyncResponse when the client already has the latest network map.
|
||||||
|
func ToSkipSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, checks []*posture.Checks, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
|
||||||
|
response := &proto.SyncResponse{
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
|
}
|
||||||
|
|
||||||
|
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||||
|
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||||
|
response.NetbirdConfig = extendedConfig
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||||
|
|||||||
@@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// todo introduce something more meaningful with the key expiration/rotation
|
// todo introduce something more meaningful with the key expiration/rotation
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
@@ -194,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||||
}
|
}
|
||||||
if s.logBlockedPeers {
|
if s.logBlockedPeers {
|
||||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||||
}
|
}
|
||||||
if s.blockPeersWithSameConfig {
|
if s.blockPeersWithSameConfig {
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
|
||||||
@@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||||
|
|
||||||
@@ -246,7 +239,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncReq.GetNetworkMapSerial())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
|
|||||||
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
|
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||||
@@ -525,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
reqStart := time.Now()
|
reqStart := time.Now()
|
||||||
realIP := getRealIP(ctx)
|
realIP := getRealIP(ctx)
|
||||||
sRealIP := realIP.String()
|
sRealIP := realIP.String()
|
||||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
|
||||||
|
|
||||||
loginReq := &proto.LoginRequest{}
|
loginReq := &proto.LoginRequest{}
|
||||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||||
@@ -537,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
metahashed := metaHash(peerMeta, sRealIP)
|
metahashed := metaHash(peerMeta, sRealIP)
|
||||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||||
if s.logBlockedPeers {
|
if s.logBlockedPeers {
|
||||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||||
}
|
}
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
||||||
@@ -561,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
//nolint
|
//nolint
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||||
}
|
}
|
||||||
took := time.Since(reqStart)
|
|
||||||
if took > 7*time.Second {
|
|
||||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if loginReq.GetMeta() == nil {
|
if loginReq.GetMeta() == nil {
|
||||||
@@ -604,16 +592,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
return nil, mapError(ctx, err)
|
return nil, mapError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
|
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
|
||||||
@@ -718,7 +702,12 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
var plainResp *proto.SyncResponse
|
||||||
|
if networkMap == nil {
|
||||||
|
plainResp = ToSkipSyncResponse(ctx, s.config, peer, turnToken, relayToken, postureChecks, settings.Extra, peerGroups)
|
||||||
|
} else {
|
||||||
|
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
}
|
||||||
|
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -730,12 +719,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
sendStart := time.Now()
|
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
WgPubKey: key.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||||
@@ -750,10 +737,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -813,10 +796,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
|||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
|
|||||||
relayCancel := make(chan struct{}, 1)
|
relayCancel := make(chan struct{}, 1)
|
||||||
m.relayCancelMap[peerID] = relayCancel
|
m.relayCancelMap[peerID] = relayCancel
|
||||||
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
|
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
|
||||||
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
|
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
|
||||||
@@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
m.pushNewRelayTokens(ctx, accountID, peerID)
|
m.pushNewRelayTokens(ctx, accountID, peerID)
|
||||||
|
|||||||
@@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oldSettings.Extra != nil && newSettings.Extra != nil &&
|
||||||
|
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
|
||||||
|
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to approve pending peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if approvedCount > 0 {
|
||||||
|
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return newSettings, nil
|
return newSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||||
@@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
|||||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
|
||||||
for _, peer := range peers {
|
|
||||||
peersMap[peer.ID] = peer
|
|
||||||
}
|
|
||||||
|
|
||||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||||
@@ -787,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
|||||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||||
accountIDString := fmt.Sprintf("%v", accountID)
|
accountIDString := fmt.Sprintf("%v", accountID)
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||||
|
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -1607,8 +1617,8 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
|||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, NetworkMapSerial: clientSerial}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ type Manager interface {
|
|||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
|
|||||||
@@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
|
||||||
|
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
accountID := account.Id
|
||||||
|
userID := account.Users[account.CreatedBy].Id
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
newSettings := account.Settings.Copy()
|
||||||
|
newSettings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: true,
|
||||||
|
}
|
||||||
|
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer1.Status.RequiresApproval = true
|
||||||
|
peer2.Status.RequiresApproval = true
|
||||||
|
peer3.Status.RequiresApproval = false
|
||||||
|
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
|
||||||
|
|
||||||
|
newSettings = account.Settings.Copy()
|
||||||
|
newSettings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: false,
|
||||||
|
}
|
||||||
|
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, peer := range accountPeers {
|
||||||
|
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@@ -3107,7 +3144,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, 0)
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.AccountId != accountId {
|
if userAuth.AccountId != accountId {
|
||||||
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||||
userAuth.AccountId = accountId
|
userAuth.AccountId = accountId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
|
|||||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||||
type IntegratedValidator interface {
|
type IntegratedValidator interface {
|
||||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
|
||||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type MockAccountManager struct {
|
|||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||||
@@ -177,9 +177,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
|
|||||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
if am.SyncAndMarkPeerFunc != nil {
|
if am.SyncAndMarkPeerFunc != nil {
|
||||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, clientSerial)
|
||||||
}
|
}
|
||||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||||
|
|
||||||
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -645,7 +645,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer, 0)
|
||||||
return p, nmap, pc, err
|
return p, nmap, pc, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -731,7 +731,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer, sync.NetworkMapSerial)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||||
@@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
startTransaction := time.Now()
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
|
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
|
|
||||||
|
|
||||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -862,7 +859,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer, 0)
|
||||||
return p, nmap, pc, err
|
return p, nmap, pc, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
|
|||||||
|
|
||||||
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
|
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
|
||||||
if check == nil {
|
if check == nil {
|
||||||
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
|
|||||||
|
|
||||||
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
|
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
|
||||||
if check == nil {
|
if check == nil {
|
||||||
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ package settings
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||||
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
if userID != activity.SystemInitiator {
|
if userID != activity.SystemInitiator {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
|
|||||||
if s.metrics != nil {
|
if s.metrics != nil {
|
||||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
|
log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
|
||||||
|
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
|
||||||
|
result := s.db.Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
|
||||||
|
Update("peer_status_requires_approval", false)
|
||||||
|
if result.Error != nil {
|
||||||
|
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(result.RowsAffected), nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveUsers saves the given list of users to the database.
|
// SaveUsers saves the given list of users to the database.
|
||||||
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
||||||
if len(users) == 0 {
|
if len(users) == 0 {
|
||||||
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
|
result := tx.Take(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@@ -2152,16 +2160,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var accountNetwork types.AccountNetwork
|
var accountNetwork types.AccountNetwork
|
||||||
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewAccountNotFoundError(accountID)
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
@@ -2171,16 +2176,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
|
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
@@ -2229,11 +2231,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
|
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.NewUserNotFoundError(userID)
|
return status.NewUserNotFoundError(userID)
|
||||||
@@ -2491,16 +2490,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey types.SetupKey
|
var setupKey types.SetupKey
|
||||||
result := tx.WithContext(ctx).
|
result := tx.
|
||||||
Take(&setupKey, GetKeyQueryCondition(s), key)
|
Take(&setupKey, GetKeyQueryCondition(s), key)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -2514,10 +2510,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
result := s.db.Model(&types.SetupKey{}).
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
|
|
||||||
Where(idQueryCondition, setupKeyID).
|
Where(idQueryCondition, setupKeyID).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]interface{}{
|
||||||
"used_times": gorm.Expr("used_times + 1"),
|
"used_times": gorm.Expr("used_times + 1"),
|
||||||
@@ -2537,11 +2530,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
|||||||
|
|
||||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var groupID string
|
var groupID string
|
||||||
_ = s.db.WithContext(ctx).Model(types.Group{}).
|
_ = s.db.Model(types.Group{}).
|
||||||
Select("id").
|
Select("id").
|
||||||
Where("account_id = ? AND name = ?", accountID, "All").
|
Where("account_id = ? AND name = ?", accountID, "All").
|
||||||
Limit(1).
|
Limit(1).
|
||||||
@@ -2569,9 +2559,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
|||||||
|
|
||||||
// AddPeerToGroup adds a peer to a group
|
// AddPeerToGroup adds a peer to a group
|
||||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
peer := &types.GroupPeer{
|
peer := &types.GroupPeer{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
@@ -2768,10 +2755,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
if err := s.db.Create(peer).Error; err != nil {
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
|
||||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2897,10 +2881,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
ctx, cancel := getDebuggingCtx(ctx)
|
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||||
@@ -4022,36 +4003,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
|
|||||||
return groupPeers, nil
|
return groupPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
|
||||||
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
|
|
||||||
if ok {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
|
|
||||||
if ok {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
|
|
||||||
if ok {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-grpcCtx.Done():
|
|
||||||
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return ctx, cancel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||||
var info types.PrimaryAccountInfo
|
var info types.PrimaryAccountInfo
|
||||||
result := s.db.Model(&types.Account{}).
|
result := s.db.Model(&types.Account{}).
|
||||||
@@ -4091,7 +4042,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
|
|||||||
Network: &types.Network{Net: ipNet},
|
Network: &types.Network{Net: ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).
|
result := s.db.
|
||||||
Model(&types.Account{}).
|
Model(&types.Account{}).
|
||||||
Where(idQueryCondition, accountID).
|
Where(idQueryCondition, accountID).
|
||||||
Updates(&patch)
|
Updates(&patch)
|
||||||
|
|||||||
@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
|
||||||
|
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||||
|
accountID := "test-account"
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
account := newAccountWithId(ctx, accountID, "testuser", "example.com")
|
||||||
|
err := store.SaveAccount(ctx, account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peers := []*nbpeer.Peer{
|
||||||
|
{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: accountID,
|
||||||
|
DNSLabel: "peer1.netbird.cloud",
|
||||||
|
Key: "peer1-key",
|
||||||
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
RequiresApproval: true,
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "peer2",
|
||||||
|
AccountID: accountID,
|
||||||
|
DNSLabel: "peer2.netbird.cloud",
|
||||||
|
Key: "peer2-key",
|
||||||
|
IP: net.ParseIP("100.64.0.2"),
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
RequiresApproval: true,
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "peer3",
|
||||||
|
AccountID: accountID,
|
||||||
|
DNSLabel: "peer3.netbird.cloud",
|
||||||
|
Key: "peer3-key",
|
||||||
|
IP: net.ParseIP("100.64.0.3"),
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
RequiresApproval: false,
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
err = store.AddPeerToAccount(ctx, peer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("approve all pending peers", func(t *testing.T) {
|
||||||
|
count, err := store.ApproveAccountPeers(ctx, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, peer := range allPeers {
|
||||||
|
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no peers to approve", func(t *testing.T) {
|
||||||
|
count, err := store.ApproveAccountPeers(ctx, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-existent account", func(t *testing.T) {
|
||||||
|
count, err := store.ApproveAccountPeers(ctx, "non-existent")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ type Store interface {
|
|||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
|
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||||
|
|
||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ type GRPCMetrics struct {
|
|||||||
meter metric.Meter
|
meter metric.Meter
|
||||||
syncRequestsCounter metric.Int64Counter
|
syncRequestsCounter metric.Int64Counter
|
||||||
syncRequestsBlockedCounter metric.Int64Counter
|
syncRequestsBlockedCounter metric.Int64Counter
|
||||||
syncRequestHighLatencyCounter metric.Int64Counter
|
|
||||||
loginRequestsCounter metric.Int64Counter
|
loginRequestsCounter metric.Int64Counter
|
||||||
loginRequestsBlockedCounter metric.Int64Counter
|
loginRequestsBlockedCounter metric.Int64Counter
|
||||||
loginRequestHighLatencyCounter metric.Int64Counter
|
loginRequestHighLatencyCounter metric.Int64Counter
|
||||||
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
|
|
||||||
metric.WithUnit("1"),
|
|
||||||
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
|
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
|
||||||
metric.WithUnit("1"),
|
metric.WithUnit("1"),
|
||||||
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
||||||
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
|||||||
meter: meter,
|
meter: meter,
|
||||||
syncRequestsCounter: syncRequestsCounter,
|
syncRequestsCounter: syncRequestsCounter,
|
||||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||||
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
|
|
||||||
loginRequestsCounter: loginRequestsCounter,
|
loginRequestsCounter: loginRequestsCounter,
|
||||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||||
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
|
|||||||
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
||||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
||||||
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||||
if duration > HighLatencyThreshold {
|
|
||||||
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||||
|
|||||||
@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
h.ServeHTTP(w, r.WithContext(ctx))
|
h.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
|
||||||
|
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err == nil {
|
||||||
|
if userAuth.AccountId != "" {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
||||||
|
}
|
||||||
|
if userAuth.UserId != "" {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if w.Status() > 399 {
|
if w.Status() > 399 {
|
||||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ type PeerSync struct {
|
|||||||
// UpdateAccountPeers indicate updating account peers,
|
// UpdateAccountPeers indicate updating account peers,
|
||||||
// which occurs when the peer's metadata is updated
|
// which occurs when the peer's metadata is updated
|
||||||
UpdateAccountPeers bool
|
UpdateAccountPeers bool
|
||||||
|
// NetworkMapSerial is the last known network map serial number on the client.
|
||||||
|
// Used to skip network map recalculation if client already has the latest.
|
||||||
|
NetworkMapSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerLogin used as a data object between the gRPC API and Manager on Login request.
|
// PeerLogin used as a data object between the gRPC API and Manager on Login request.
|
||||||
|
|||||||
31
relay/healthcheck/peerid/peerid.go
Normal file
31
relay/healthcheck/peerid/peerid.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package peerid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
|
||||||
|
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||||
|
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// HealthCheckPeerID is the hashed peer ID for health check connections
|
||||||
|
HealthCheckPeerID = messages.HashID("healthcheck-agent")
|
||||||
|
|
||||||
|
// DummyAuthToken is a structurally valid auth token for health check.
|
||||||
|
// The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
|
||||||
|
DummyAuthToken = createDummyToken()
|
||||||
|
)
|
||||||
|
|
||||||
|
func createDummyToken() []byte {
|
||||||
|
token := v2.Token{
|
||||||
|
AuthAlgo: v2.AuthAlgoHMACSHA256,
|
||||||
|
Signature: make([]byte, sha256.Size),
|
||||||
|
Payload: []byte("healthcheck"),
|
||||||
|
}
|
||||||
|
return token.Marshal()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHealthCheck checks if the given peer ID is the health check agent
|
||||||
|
func IsHealthCheck(peerID *messages.PeerID) bool {
|
||||||
|
return peerID != nil && *peerID == HealthCheckPeerID
|
||||||
|
}
|
||||||
@@ -7,8 +7,10 @@ import (
|
|||||||
|
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
"github.com/netbirdio/netbird/shared/relay"
|
"github.com/netbirdio/netbird/shared/relay"
|
||||||
|
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
func dialWS(ctx context.Context, address url.URL) error {
|
func dialWS(ctx context.Context, address url.URL) error {
|
||||||
@@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to websocket: %w", err)
|
return fmt.Errorf("failed to connect to websocket: %w", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
|
||||||
|
return fmt.Errorf("failed to write auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
|||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return peerID, err
|
||||||
}
|
}
|
||||||
h.peerID = peerID
|
h.peerID = peerID
|
||||||
return peerID, nil
|
return peerID, nil
|
||||||
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.validator.Validate(authPayload); err != nil {
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
|
return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawPeerID, nil
|
return rawPeerID, nil
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
"github.com/netbirdio/netbird/relay/server/store"
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
@@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
peerID, err := h.handshakeReceive()
|
peerID, err := h.handshakeReceive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to handshake: %s", err)
|
if peerid.IsHealthCheck(peerID) {
|
||||||
|
log.Debugf("health check connection from %s", conn.RemoteAddr())
|
||||||
|
} else {
|
||||||
|
log.Errorf("failed to handshake: %s", err)
|
||||||
|
}
|
||||||
if cErr := conn.Close(); cErr != nil {
|
if cErr := conn.Close(); cErr != nil {
|
||||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKey() (*wgtypes.Key, error)
|
GetServerPublicKey() (*wgtypes.Key, error)
|
||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error {
|
err = client.Sync(ctx, info, 0, func(msg *mgmtProto.SyncResponse) error {
|
||||||
ch <- msg
|
ch <- msg
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ package common
|
|||||||
//
|
//
|
||||||
// | Value | Flag | OAuth Parameters |
|
// | Value | Flag | OAuth Parameters |
|
||||||
// |-------|----------------------|-----------------------------------------|
|
// |-------|----------------------|-----------------------------------------|
|
||||||
// | 0 | LoginFlagPromptLogin | prompt=select_account login |
|
// | 0 | LoginFlagPromptLogin | prompt=login |
|
||||||
// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account |
|
// | 1 | LoginFlagMaxAge0 | max_age=0 |
|
||||||
type LoginFlag uint8
|
type LoginFlag uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// LoginFlagPromptLogin adds prompt=select_account login to the authorization request
|
// LoginFlagPromptLogin adds prompt=login to the authorization request
|
||||||
LoginFlagPromptLogin LoginFlag = iota
|
LoginFlagPromptLogin LoginFlag = iota
|
||||||
// LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request
|
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
|
||||||
LoginFlagMaxAge0
|
LoginFlagMaxAge0
|
||||||
// LoginFlagNone disables all login flags
|
// LoginFlagNone disables all login flags
|
||||||
LoginFlagNone
|
LoginFlagNone
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func (c *GrpcClient) ready() bool {
|
|||||||
|
|
||||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||||
// Blocking request. The result will be sent via msgHandler callback function
|
// Blocking request. The result will be sent via msgHandler callback function
|
||||||
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
log.Debugf("management connection state %v", c.conn.GetState())
|
log.Debugf("management connection state %v", c.conn.GetState())
|
||||||
connState := c.conn.GetState()
|
connState := c.conn.GetState()
|
||||||
@@ -128,7 +128,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
|
return c.handleStream(ctx, *serverPubKey, sysInfo, networkSerial, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, defaultBackoff(ctx))
|
err := backoff.Retry(operation, defaultBackoff(ctx))
|
||||||
@@ -140,11 +140,11 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
||||||
msgHandler func(msg *proto.SyncResponse) error) error {
|
networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
ctx, cancelStream := context.WithCancel(ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
|
|
||||||
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
|
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo, networkSerial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to open Management Service stream: %s", err)
|
log.Debugf("failed to open Management Service stream: %s", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||||
@@ -186,7 +186,8 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
|
|||||||
|
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
|
// GetNetworkMap doesn't have a serial to send, so we pass 0
|
||||||
|
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to open Management Service stream: %s", err)
|
log.Debugf("failed to open Management Service stream: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -219,8 +220,17 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
|
|||||||
return decryptedResp.GetNetworkMap(), nil
|
return decryptedResp.GetNetworkMap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
|
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, networkSerial uint64) (proto.ManagementService_SyncClient, error) {
|
||||||
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
|
// Always compute latest system info to ensure up-to-date PeerSystemMeta on first and subsequent syncs
|
||||||
|
recomputed := system.GetInfo(c.ctx)
|
||||||
|
if sysInfo != nil {
|
||||||
|
recomputed.CopyFlagsFrom(sysInfo)
|
||||||
|
// carry over posture files if any were computed
|
||||||
|
if len(sysInfo.Files) > 0 {
|
||||||
|
recomputed.Files = sysInfo.Files
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req := &proto.SyncRequest{Meta: infoToMetaData(recomputed), NetworkMapSerial: networkSerial}
|
||||||
|
|
||||||
myPrivateKey := c.key
|
myPrivateKey := c.key
|
||||||
myPublicKey := myPrivateKey.PublicKey()
|
myPublicKey := myPrivateKey.PublicKey()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
SyncFunc func(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
@@ -33,11 +33,11 @@ func (m *MockClient) Close() error {
|
|||||||
return m.CloseFunc()
|
return m.CloseFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, networkSerial uint64, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
if m.SyncFunc == nil {
|
if m.SyncFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.SyncFunc(ctx, sysInfo, msgHandler)
|
return m.SyncFunc(ctx, sysInfo, networkSerial, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||||
|
|||||||
@@ -7,12 +7,13 @@
|
|||||||
package proto
|
package proto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
|
||||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
durationpb "google.golang.org/protobuf/types/known/durationpb"
|
durationpb "google.golang.org/protobuf/types/known/durationpb"
|
||||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
reflect "reflect"
|
|
||||||
sync "sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -343,6 +344,8 @@ type SyncRequest struct {
|
|||||||
|
|
||||||
// Meta data of the peer
|
// Meta data of the peer
|
||||||
Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"`
|
Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"`
|
||||||
|
// Optional: last known NetworkMap serial number on the client
|
||||||
|
NetworkMapSerial uint64 `protobuf:"varint,2,opt,name=networkMapSerial,proto3" json:"networkMapSerial,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *SyncRequest) Reset() {
|
func (x *SyncRequest) Reset() {
|
||||||
@@ -384,6 +387,13 @@ func (x *SyncRequest) GetMeta() *PeerSystemMeta {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *SyncRequest) GetNetworkMapSerial() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.NetworkMapSerial
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
|
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
|
||||||
type SyncResponse struct {
|
type SyncResponse struct {
|
||||||
state protoimpl.MessageState
|
state protoimpl.MessageState
|
||||||
@@ -402,6 +412,8 @@ type SyncResponse struct {
|
|||||||
NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"`
|
NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"`
|
||||||
// Posture checks to be evaluated by client
|
// Posture checks to be evaluated by client
|
||||||
Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"`
|
Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"`
|
||||||
|
// Indicates whether the client should skip updating the network map
|
||||||
|
SkipNetworkMapUpdate bool `protobuf:"varint,7,opt,name=skipNetworkMapUpdate,proto3" json:"skipNetworkMapUpdate,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *SyncResponse) Reset() {
|
func (x *SyncResponse) Reset() {
|
||||||
@@ -478,6 +490,13 @@ func (x *SyncResponse) GetChecks() []*Checks {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *SyncResponse) GetSkipNetworkMapUpdate() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.SkipNetworkMapUpdate
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type SyncMetaRequest struct {
|
type SyncMetaRequest struct {
|
||||||
state protoimpl.MessageState
|
state protoimpl.MessageState
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
@@ -3518,33 +3537,39 @@ var file_management_proto_rawDesc = []byte{
|
|||||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12,
|
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12,
|
||||||
0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62,
|
0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62,
|
||||||
0x6f, 0x64, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03,
|
0x6f, 0x64, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03,
|
||||||
0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3d, 0x0a,
|
0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x69, 0x0a,
|
||||||
0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04,
|
0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04,
|
||||||
0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
|
0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e,
|
||||||
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74,
|
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74,
|
||||||
0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x22, 0xdb, 0x02, 0x0a,
|
0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x12, 0x2a, 0x0a, 0x10,
|
||||||
0x0c, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a,
|
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c,
|
||||||
0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01,
|
0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
|
||||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
|
0x61, 0x70, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x22, 0x8f, 0x03, 0x0a, 0x0c, 0x53, 0x79, 0x6e,
|
||||||
0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
|
0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74,
|
||||||
0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36,
|
0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b,
|
||||||
0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01,
|
0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65,
|
||||||
0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
|
0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74,
|
||||||
0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72,
|
0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65,
|
||||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65,
|
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
|
||||||
0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61,
|
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72,
|
||||||
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50,
|
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66,
|
||||||
0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74,
|
0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72,
|
||||||
0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65,
|
0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
|
||||||
0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01,
|
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43,
|
||||||
0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49,
|
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65,
|
||||||
0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x36, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
|
0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72,
|
||||||
0x6b, 0x4d, 0x61, 0x70, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e,
|
0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12,
|
||||||
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
|
0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70,
|
||||||
0x61, 0x70, 0x52, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x2a,
|
0x74, 0x79, 0x12, 0x36, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
|
||||||
0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12,
|
0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
|
||||||
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63,
|
0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x52, 0x0a,
|
||||||
0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x41, 0x0a, 0x0f, 0x53, 0x79,
|
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68,
|
||||||
|
0x65, 0x63, 0x6b, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e,
|
||||||
|
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06,
|
||||||
|
0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x73, 0x6b, 0x69, 0x70, 0x4e, 0x65,
|
||||||
|
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x07,
|
||||||
|
0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x73, 0x6b, 0x69, 0x70, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
|
||||||
|
0x6b, 0x4d, 0x61, 0x70, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x41, 0x0a, 0x0f, 0x53, 0x79,
|
||||||
0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a,
|
0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a,
|
||||||
0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61,
|
0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61,
|
||||||
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73,
|
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73,
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ message EncryptedMessage {
|
|||||||
message SyncRequest {
|
message SyncRequest {
|
||||||
// Meta data of the peer
|
// Meta data of the peer
|
||||||
PeerSystemMeta meta = 1;
|
PeerSystemMeta meta = 1;
|
||||||
|
// Optional: last known NetworkMap serial number on the client
|
||||||
|
uint64 networkMapSerial = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
|
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
|
||||||
@@ -85,6 +87,9 @@ message SyncResponse {
|
|||||||
|
|
||||||
// Posture checks to be evaluated by client
|
// Posture checks to be evaluated by client
|
||||||
repeated Checks Checks = 6;
|
repeated Checks Checks = 6;
|
||||||
|
|
||||||
|
// Indicates whether the client should skip updating the network map
|
||||||
|
bool skipNetworkMapUpdate = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SyncMetaRequest {
|
message SyncMetaRequest {
|
||||||
|
|||||||
Reference in New Issue
Block a user