mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
Compare commits
9 Commits
fix/limit-
...
v0.30.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96d2207684 | ||
|
|
f942491b91 | ||
|
|
8c8900be57 | ||
|
|
cee95461d1 | ||
|
|
49e65109d2 | ||
|
|
d93dd4fc7f | ||
|
|
3a88ac78ff | ||
|
|
da3a053e2b | ||
|
|
0e95f16cdd |
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.14"
|
SIGN_PIPE_VER: "v0.0.15"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||||
@@ -223,4 +223,4 @@ jobs:
|
|||||||
repo: netbirdio/sign-pipelines
|
repo: netbirdio/sign-pipelines
|
||||||
ref: ${{ env.SIGN_PIPE_VER }}
|
ref: ${{ env.SIGN_PIPE_VER }}
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||||
|
|||||||
@@ -96,6 +96,9 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
universal_binaries:
|
||||||
|
- id: netbird
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ builds:
|
|||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
|
universal_binaries:
|
||||||
|
- id: netbird-ui-darwin
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird-ui-darwin
|
- netbird-ui-darwin
|
||||||
|
|||||||
@@ -433,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||||
intdir := "-i"
|
intdir := "-i"
|
||||||
|
lointdir := "-o"
|
||||||
if inverse {
|
if inverse {
|
||||||
intdir = "-o"
|
intdir = "-o"
|
||||||
|
lointdir = "-i"
|
||||||
}
|
}
|
||||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
|
|||||||
@@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
|
|||||||
rule := &nftables.Rule{
|
rule := &nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: getEstablishedExprs(1),
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.InsertRule(rule)
|
conn.InsertRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getEstablishedExprs(register uint32) []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: register,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: register,
|
||||||
|
DestRegister: register,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: register,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: []byte{0, 0, 0, 0},
|
Data: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
|
&expr.Counter{},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
@@ -81,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.cleanUpDefaultForwardRules()
|
err = r.removeAcceptForwardRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
@@ -98,40 +99,7 @@ func (r *router) Reset() error {
|
|||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
r.ipsetCounter.Clear()
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
return r.cleanUpDefaultForwardRules()
|
return r.removeAcceptForwardRules()
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list chains: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
@@ -167,7 +135,9 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
r.acceptForwardRules()
|
if err := r.acceptForwardRules(); err != nil {
|
||||||
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
@@ -455,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
dir := expr.MetaKeyIIFNAME
|
dir := expr.MetaKeyIIFNAME
|
||||||
|
notDir := expr.MetaKeyOIFNAME
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
dir = expr.MetaKeyOIFNAME
|
dir = expr.MetaKeyOIFNAME
|
||||||
|
notDir = expr.MetaKeyIIFNAME
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lo := ifname("lo")
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: dir,
|
Key: dir,
|
||||||
@@ -470,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||||
|
&expr.Meta{
|
||||||
|
Key: notDir,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: lo,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
@@ -577,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// that our traffic is not dropped by existing rules there.
|
// that our traffic is not dropped by existing rules there.
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
func (r *router) acceptForwardRules() {
|
func (r *router) acceptForwardRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fw := "iptables"
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Debugf("Used %s to add accept forward rules", fw)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
|
ipt, err := iptables.New()
|
||||||
|
if err != nil {
|
||||||
|
// filter table exists but iptables is not
|
||||||
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
|
fw = "nftables"
|
||||||
|
return r.acceptForwardRulesNftables()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.acceptForwardRulesIptables(ipt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
|
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||||
|
} else {
|
||||||
|
log.Debugf("added iptables rule: %v", rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) getAcceptForwardRules() [][]string {
|
||||||
|
intf := r.wgIface.Name()
|
||||||
|
return [][]string{
|
||||||
|
{"-i", intf, "-j", "ACCEPT"},
|
||||||
|
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptForwardRulesNftables() error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
// Rule for incoming interface (iif) with counter
|
// Rule for incoming interface (iif) with counter
|
||||||
iifRule := &nftables.Rule{
|
iifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
Chain: &nftables.Chain{
|
Chain: &nftables.Chain{
|
||||||
Name: "FORWARD",
|
Name: chainNameForward,
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
@@ -609,6 +635,15 @@ func (r *router) acceptForwardRules() {
|
|||||||
}
|
}
|
||||||
r.conn.InsertRule(iifRule)
|
r.conn.InsertRule(iifRule)
|
||||||
|
|
||||||
|
oifExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Rule for outgoing interface (oif) with counter
|
// Rule for outgoing interface (oif) with counter
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
@@ -619,36 +654,72 @@ func (r *router) acceptForwardRules() {
|
|||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
Priority: nftables.ChainPriorityFilter,
|
Priority: nftables.ChainPriorityFilter,
|
||||||
},
|
},
|
||||||
Exprs: []expr.Any{
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: intf,
|
|
||||||
},
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 2,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
},
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRules() error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
|
ipt, err := iptables.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
return r.removeAcceptForwardRulesNftables()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.removeAcceptForwardRulesIptables(ipt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||||
|
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
|
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||||
|
|||||||
@@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
@@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
|
|||||||
@@ -82,8 +82,6 @@ type Conn struct {
|
|||||||
config ConnConfig
|
config ConnConfig
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
wgProxyFactory *wgproxy.Factory
|
wgProxyFactory *wgproxy.Factory
|
||||||
wgProxyICE wgproxy.Proxy
|
|
||||||
wgProxyRelay wgproxy.Proxy
|
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
@@ -106,7 +104,8 @@ type Conn struct {
|
|||||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||||
|
|
||||||
endpointRelay *net.UDPAddr
|
wgProxyICE wgproxy.Proxy
|
||||||
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
|
||||||
// for reconnection operations
|
// for reconnection operations
|
||||||
iCEDisconnected chan bool
|
iCEDisconnected chan bool
|
||||||
@@ -257,8 +256,7 @@ func (conn *Conn) Close() {
|
|||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
conn.statusICE.Set(StatusConnected)
|
|
||||||
|
|
||||||
defer conn.updateIceState(iceConnInfo)
|
|
||||||
|
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("set ICE to active connection")
|
conn.log.Infof("set ICE to active connection")
|
||||||
|
|
||||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
var (
|
||||||
if err != nil {
|
ep *net.UDPAddr
|
||||||
return
|
wgProxy wgproxy.Proxy
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if iceConnInfo.RelayedOnLocal {
|
||||||
|
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||||
|
if err != nil {
|
||||||
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = wgProxy.EndpointAddr()
|
||||||
|
conn.wgProxyICE = wgProxy
|
||||||
|
} else {
|
||||||
|
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to resolveUDPaddr")
|
||||||
|
conn.handleConfigurationFailure(err, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
|
|
||||||
conn.connIDICE = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
if conn.wgProxyRelay != nil {
|
||||||
if err != nil {
|
conn.wgProxyRelay.Pause()
|
||||||
if wgProxy != nil {
|
}
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
if wgProxy != nil {
|
||||||
}
|
wgProxy.Work()
|
||||||
}
|
}
|
||||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
|
||||||
|
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||||
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyICE = wgProxy
|
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||||
|
|
||||||
|
if conn.wgProxyICE != nil {
|
||||||
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
|
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
conn.wgProxyRelay.Work()
|
||||||
if err != nil {
|
|
||||||
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
@@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||||
conn.statusICE.Set(newState)
|
conn.statusICE.Set(newState)
|
||||||
|
|
||||||
select {
|
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||||
case conn.iCEDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
if err := rci.relayedConn.Close(); err != nil {
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
conn.statusRelay.Set(StatusConnected)
|
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
conn.endpointRelay = endpointUdpAddr
|
|
||||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
|
||||||
|
|
||||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
if conn.iceP2PIsActive() {
|
||||||
|
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
if conn.currentConnPriority > connPriorityRelay {
|
conn.wgProxyRelay = wgProxy
|
||||||
if conn.statusICE.Get() == StatusConnected {
|
conn.statusRelay.Set(StatusConnected)
|
||||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.connIDRelay = nbnet.GenerateConnID()
|
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
wgProxy.Work()
|
||||||
if err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
|
||||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyRelay = wgProxy
|
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
conn.wgProxyRelay = wgProxy
|
||||||
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
}
|
}
|
||||||
@@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("relay connection is disconnected")
|
conn.log.Debugf("relay connection is disconnected")
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityRelay {
|
if conn.currentConnPriority == connPriorityRelay {
|
||||||
log.Debugf("clean up WireGuard config")
|
conn.log.Debugf("clean up WireGuard config")
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
conn.endpointRelay = nil
|
|
||||||
_ = conn.wgProxyRelay.CloseConn()
|
_ = conn.wgProxyRelay.CloseConn()
|
||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
|
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||||
select {
|
|
||||||
case conn.relayDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
Relayed: conn.isRelayed(),
|
Relayed: conn.isRelayed(),
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
}
|
}
|
||||||
|
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
|
||||||
if err != nil {
|
|
||||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||||
|
conn.connIDICE = nbnet.GenerateConnID()
|
||||||
|
for _, hook := range conn.beforeAddPeerHooks {
|
||||||
|
if err := hook(conn.connIDICE, ip); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) freeUpConnID() {
|
func (conn *Conn) freeUpConnID() {
|
||||||
if conn.connIDRelay != "" {
|
if conn.connIDRelay != "" {
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
@@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
if !iceConnInfo.RelayedOnLocal {
|
conn.log.Debugf("setup proxied WireGuard connection")
|
||||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
|
||||||
}
|
|
||||||
conn.log.Debugf("setup ice turn connection")
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
return nil, err
|
||||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
}
|
||||||
}
|
return wgProxy, nil
|
||||||
return nil, nil, err
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) iceP2PIsActive() bool {
|
||||||
|
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) removeWgPeer() error {
|
||||||
|
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.relayDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.iCEDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
|
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
|
if wgProxy != nil {
|
||||||
|
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||||
|
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.wgProxyRelay.Work()
|
||||||
}
|
}
|
||||||
return ep, wgProxy, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ package ebpf
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn add new turn connection for the proxy
|
// AddTurnConn add new turn connection for the proxy
|
||||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
|
||||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||||
|
|
||||||
wgEndpoint := &net.UDPAddr{
|
wgEndpoint := &net.UDPAddr{
|
||||||
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
|
||||||
defer p.removeTurnConn(endpointPort)
|
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
n int
|
|
||||||
)
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
n, err = remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
|
||||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
// From this go routine has only one instance.
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
return packetConn, nil
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||||
localhost := net.ParseIP("127.0.0.1")
|
localhost := net.ParseIP("127.0.0.1")
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
payload := gopacket.Payload(data)
|
||||||
|
|||||||
@@ -4,8 +4,13 @@ package ebpf
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
|
|||||||
WgeBPFProxy *WGEBPFProxy
|
WgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
wgEndpointAddr *net.UDPAddr
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
ctxConn, cancel := context.WithCancel(ctx)
|
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel()
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
|
||||||
}
|
}
|
||||||
e.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
e.cancel = cancel
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return addr, err
|
p.wgEndpointAddr = addr
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
|
return p.wgEndpointAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||||
|
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
n, err := p.readFromRemote(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||||
|
n, err := p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ import (
|
|||||||
|
|
||||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||||
|
EndpointAddr() *net.UDPAddr
|
||||||
|
Work()
|
||||||
|
Pause()
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,13 +15,17 @@ import (
|
|||||||
// WGUserSpaceProxy proxies
|
// WGUserSpaceProxy proxies
|
||||||
type WGUserSpaceProxy struct {
|
type WGUserSpaceProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||||
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
// The provided Context must be non-nil. If the context expires before
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
// the connection is complete, an error is returned. Once successfully
|
||||||
|
// connected, any expiration of the context will not affect the
|
||||||
p.remoteConn = remoteConn
|
// connection.
|
||||||
|
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
var err error
|
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToRemote()
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
go p.proxyToLocal()
|
p.localConn = localConn
|
||||||
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
return p.localConn.LocalAddr(), err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||||
|
if p.localConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||||
|
return endpointUdpAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work starts the proxy or resumes it if it was paused
|
||||||
|
func (p *WGUserSpaceProxy) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToRemote(p.ctx)
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
|
func (p *WGUserSpaceProxy) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to remote loop: %s", err)
|
log.Warnf("error in proxy to remote loop: %s", err)
|
||||||
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
|
|
||||||
_, err = p.remoteConn.Write(buf[:n])
|
_, err = p.remoteConn.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||||
|
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to local loop: %s", err)
|
log.Warnf("error in proxy to local loop: %s", err)
|
||||||
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for {
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
_, err = p.localConn.Write(buf[:n])
|
_, err = p.localConn.Write(buf[:n])
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
|
||||||
<plist version="1.0">
|
|
||||||
<dict>
|
|
||||||
<key>CFBundleExecutable</key>
|
|
||||||
<string>netbird-ui</string>
|
|
||||||
<key>CFBundleIconFile</key>
|
|
||||||
<string>Netbird</string>
|
|
||||||
<key>LSUIElement</key>
|
|
||||||
<string>1</string>
|
|
||||||
</dict>
|
|
||||||
</plist>
|
|
||||||
@@ -873,7 +873,7 @@ services:
|
|||||||
zitadel:
|
zitadel:
|
||||||
restart: 'always'
|
restart: 'always'
|
||||||
networks: [netbird]
|
networks: [netbird]
|
||||||
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
|
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
||||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||||
env_file:
|
env_file:
|
||||||
- ./zitadel.env
|
- ./zitadel.env
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@@ -50,6 +51,9 @@ const (
|
|||||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
|
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||||
|
emptyUserID = "empty user ID in claims"
|
||||||
|
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userLoggedInOnce bool
|
type userLoggedInOnce bool
|
||||||
@@ -178,6 +182,8 @@ type DefaultAccountManager struct {
|
|||||||
dnsDomain string
|
dnsDomain string
|
||||||
peerLoginExpiry Scheduler
|
peerLoginExpiry Scheduler
|
||||||
|
|
||||||
|
peerInactivityExpiry Scheduler
|
||||||
|
|
||||||
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||||
userDeleteFromIDPEnabled bool
|
userDeleteFromIDPEnabled bool
|
||||||
|
|
||||||
@@ -195,6 +201,13 @@ type Settings struct {
|
|||||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||||
PeerLoginExpiration time.Duration
|
PeerLoginExpiration time.Duration
|
||||||
|
|
||||||
|
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
|
||||||
|
PeerInactivityExpirationEnabled bool
|
||||||
|
|
||||||
|
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
|
||||||
|
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
|
||||||
|
PeerInactivityExpiration time.Duration
|
||||||
|
|
||||||
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
||||||
RegularUsersViewBlocked bool
|
RegularUsersViewBlocked bool
|
||||||
|
|
||||||
@@ -225,6 +238,9 @@ func (s *Settings) Copy() *Settings {
|
|||||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||||
JWTAllowGroups: s.JWTAllowGroups,
|
JWTAllowGroups: s.JWTAllowGroups,
|
||||||
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||||
}
|
}
|
||||||
if s.Extra != nil {
|
if s.Extra != nil {
|
||||||
settings.Extra = s.Extra.Copy()
|
settings.Extra = s.Extra.Copy()
|
||||||
@@ -606,6 +622,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
|
|||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInactivePeers returns peers that have been expired by inactivity
|
||||||
|
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
for _, inactivePeer := range a.GetPeersWithInactivity() {
|
||||||
|
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||||
|
if inactive {
|
||||||
|
peers = append(peers, inactivePeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
|
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||||
|
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
|
||||||
|
peersWithExpiry := a.GetPeersWithInactivity()
|
||||||
|
if len(peersWithExpiry) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
var nextExpiry *time.Duration
|
||||||
|
for _, peer := range peersWithExpiry {
|
||||||
|
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||||
|
if nextExpiry == nil || duration < *nextExpiry {
|
||||||
|
// if expiration is below 1s return 1s duration
|
||||||
|
// this avoids issues with ticker that can't be set to < 0
|
||||||
|
if duration < time.Second {
|
||||||
|
return time.Second, true
|
||||||
|
}
|
||||||
|
nextExpiry = &duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextExpiry == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *nextExpiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
|
||||||
|
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
|
||||||
|
peers := make([]*nbpeer.Peer, 0)
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeers returns a list of all Account peers
|
// GetPeers returns a list of all Account peers
|
||||||
func (a *Account) GetPeers() []*nbpeer.Peer {
|
func (a *Account) GetPeers() []*nbpeer.Peer {
|
||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
@@ -972,6 +1042,7 @@ func BuildManager(
|
|||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
eventStore: eventStore,
|
eventStore: eventStore,
|
||||||
peerLoginExpiry: NewDefaultScheduler(),
|
peerLoginExpiry: NewDefaultScheduler(),
|
||||||
|
peerInactivityExpiry: NewDefaultScheduler(),
|
||||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
@@ -1100,6 +1171,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
updatedAccount := account.UpdateSettings(newSettings)
|
updatedAccount := account.UpdateSettings(newSettings)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SaveAccount(ctx, account)
|
||||||
@@ -1110,6 +1186,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return updatedAccount, nil
|
return updatedAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||||
|
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event := activity.AccountPeerInactivityExpirationEnabled
|
||||||
|
if !newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event = activity.AccountPeerInactivityExpirationDisabled
|
||||||
|
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||||
|
} else {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
return func() (time.Duration, bool) {
|
return func() (time.Duration, bool) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
@@ -1145,6 +1241,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
|
||||||
|
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
|
return func() (time.Duration, bool) {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetInactivePeers()
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||||
|
|
||||||
|
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||||
|
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||||
|
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
||||||
|
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
||||||
|
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
||||||
|
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||||
// If ID is already in use (due to collision) we try one more time before returning error
|
// If ID is already in use (due to collision) we try one more time before returning error
|
||||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
||||||
@@ -1285,7 +1418,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
|||||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return account.Id, nil
|
return account.Id, nil
|
||||||
@@ -1300,28 +1433,39 @@ func isNil(i idp.Manager) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
|
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cachedAccount := &Account{
|
||||||
|
Id: accountID,
|
||||||
|
Users: make(map[string]*User),
|
||||||
|
}
|
||||||
|
for _, user := range accountUsers {
|
||||||
|
cachedAccount.Users[user.Id] = user
|
||||||
|
}
|
||||||
|
|
||||||
// user can be nil if it wasn't found (e.g., just created)
|
// user can be nil if it wasn't found (e.g., just created)
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||||
// it was already set, so we skip the unnecessary update
|
// it was already set, so we skip the unnecessary update
|
||||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||||
account.Id, userID)
|
accountID, userID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||||
}
|
}
|
||||||
// refresh cache to reflect the update
|
// refresh cache to reflect the update
|
||||||
_, err = am.refreshCache(ctx, account.Id)
|
_, err = am.refreshCache(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1545,48 +1689,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
|||||||
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
|
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||||
primaryDomain bool,
|
primaryDomain bool,
|
||||||
) error {
|
) error {
|
||||||
|
if claims.Domain == "" {
|
||||||
if claims.Domain != "" {
|
|
||||||
account.IsDomainPrimaryAccount = primaryDomain
|
|
||||||
|
|
||||||
lowerDomain := strings.ToLower(claims.Domain)
|
|
||||||
userObj := account.Users[claims.UserId]
|
|
||||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
|
||||||
account.Domain = lowerDomain
|
|
||||||
}
|
|
||||||
// prevent updating category for different domain until admin logs in
|
|
||||||
if account.Domain == lowerDomain {
|
|
||||||
account.DomainCategory = claims.DomainCategory
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := am.Store.SaveAccount(ctx, account)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlockAccount()
|
||||||
|
|
||||||
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newDomain := accountDomain
|
||||||
|
newCategoty := domainCategory
|
||||||
|
|
||||||
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
|
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||||
|
newDomain = lowerDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountDomain == lowerDomain {
|
||||||
|
newCategoty = claims.DomainCategory
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||||
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
|
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||||
|
// and peers that shouldn't be lost.
|
||||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
existingAcc *Account,
|
userAccountID string,
|
||||||
primaryDomain bool,
|
domainAccountID string,
|
||||||
claims jwtclaims.AuthorizationClaims,
|
claims jwtclaims.AuthorizationClaims,
|
||||||
) error {
|
) error {
|
||||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||||
|
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// we should register the account ID to this user's metadata in our IDP manager
|
// we should register the account ID to this user's metadata in our IDP manager
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1594,44 +1759,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||||
// otherwise it will create a new account and make it primary account for the domain.
|
// otherwise it will create a new account and make it primary account for the domain.
|
||||||
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
var (
|
|
||||||
account *Account
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
lowerDomain := strings.ToLower(claims.Domain)
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
// if domain already has a primary account, add regular user
|
|
||||||
if domainAcc != nil {
|
|
||||||
account = domainAcc
|
|
||||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
|
newAccount.Domain = lowerDomain
|
||||||
|
newAccount.DomainCategory = claims.DomainCategory
|
||||||
|
newAccount.IsDomainPrimaryAccount = true
|
||||||
|
|
||||||
return account, nil
|
err = am.Store.SaveAccount(ctx, newAccount)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||||
|
|
||||||
|
return newAccount.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
|
defer unlockAccount()
|
||||||
|
|
||||||
|
usersMap := make(map[string]*User)
|
||||||
|
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
|
||||||
|
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||||
|
|
||||||
|
return domainAccountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
@@ -1775,7 +1954,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
|||||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return "", "", fmt.Errorf("user ID is empty")
|
return "", "", errors.New(emptyUserID)
|
||||||
}
|
}
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@@ -1961,16 +2140,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||||
|
// if domain is not private or domain is invalid, it will return the account ID by user ID.
|
||||||
// if domain is of the PrivateCategory category, it will evaluate
|
// if domain is of the PrivateCategory category, it will evaluate
|
||||||
// if account is new, existing or if there is another account with the same domain
|
// if account is new, existing or if there is another account with the same domain
|
||||||
//
|
//
|
||||||
// Use cases:
|
// Use cases:
|
||||||
//
|
//
|
||||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
|
||||||
//
|
//
|
||||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
|
||||||
//
|
//
|
||||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
// New user + New account + Existing Public Domain -> create account, user role = owner
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||||
//
|
//
|
||||||
@@ -1980,98 +2160,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
|
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return "", fmt.Errorf("user ID is empty")
|
return "", errors.New(emptyUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
|
||||||
// it means that we've already classified the domain and user has an account
|
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
if claims.AccountId != "" {
|
|
||||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
|
||||||
}
|
|
||||||
return claims.AccountId, nil
|
|
||||||
}
|
|
||||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if userAccountID != claims.AccountId {
|
|
||||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
|
||||||
}
|
|
||||||
|
|
||||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
|
||||||
return userAccountID, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
if claims.AccountId != "" {
|
||||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||||
defer unlock()
|
}
|
||||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||||
|
if cancel != nil {
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if NotFound we are good to continue, otherwise return error
|
return "", err
|
||||||
e, ok := status.FromError(err)
|
|
||||||
if !ok || e.Type() != status.NotFound {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err == nil {
|
if handleNotFound(err) != nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
defer unlockAccount()
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
|
||||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
|
||||||
// and peers that shouldn't be lost.
|
|
||||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
|
||||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Id, nil
|
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
|
||||||
var domainAccount *Account
|
|
||||||
if domainAccountID != "" {
|
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
|
||||||
defer unlockAccount()
|
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return account.Id, nil
|
|
||||||
} else {
|
|
||||||
// other error
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if userAccountID != "" {
|
||||||
|
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return userAccountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainAccountID != "" {
|
||||||
|
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||||
|
}
|
||||||
|
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||||
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainAccountID != "" {
|
||||||
|
return domainAccountID, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
|
||||||
|
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||||
|
|
||||||
|
// check again if the domain has a primary account because of simultaneous requests
|
||||||
|
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return domainAccountID, cancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAccountID != claims.AccountId {
|
||||||
|
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
|
}
|
||||||
|
|
||||||
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||||
|
return claims.AccountId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We checked if the domain has a primary account already
|
||||||
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims.AccountId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNotFound(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e, ok := status.FromError(err)
|
||||||
|
if !ok || e.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||||
|
return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||||
@@ -2337,6 +2542,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||||
GroupsPropagationEnabled: true,
|
GroupsPropagationEnabled: true,
|
||||||
RegularUsersViewBlocked: true,
|
RegularUsersViewBlocked: true,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: false,
|
||||||
|
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||||
type initUserParams jwtclaims.AuthorizationClaims
|
type initUserParams jwtclaims.AuthorizationClaims
|
||||||
|
|
||||||
type test struct {
|
var (
|
||||||
|
publicDomain = "public.com"
|
||||||
|
privateDomain = "private.com"
|
||||||
|
unknownDomain = "unknown.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
defaultInitAccount := initUserParams{
|
||||||
|
Domain: publicDomain,
|
||||||
|
UserId: "defaultUser",
|
||||||
|
}
|
||||||
|
|
||||||
|
initUnknown := defaultInitAccount
|
||||||
|
initUnknown.DomainCategory = UnknownCategory
|
||||||
|
initUnknown.Domain = unknownDomain
|
||||||
|
|
||||||
|
privateInitAccount := defaultInitAccount
|
||||||
|
privateInitAccount.Domain = privateDomain
|
||||||
|
privateInitAccount.DomainCategory = PrivateCategory
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
inputClaims jwtclaims.AuthorizationClaims
|
inputClaims jwtclaims.AuthorizationClaims
|
||||||
inputInitUserParams initUserParams
|
inputInitUserParams initUserParams
|
||||||
@@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
expectedPrimaryDomainStatus bool
|
expectedPrimaryDomainStatus bool
|
||||||
expectedCreatedBy string
|
expectedCreatedBy string
|
||||||
expectedUsers []string
|
expectedUsers []string
|
||||||
}
|
}{
|
||||||
|
{
|
||||||
var (
|
name: "New User With Public Domain",
|
||||||
publicDomain = "public.com"
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
privateDomain = "private.com"
|
Domain: publicDomain,
|
||||||
unknownDomain = "unknown.com"
|
UserId: "pub-domain-user",
|
||||||
)
|
DomainCategory: PublicCategory,
|
||||||
|
},
|
||||||
defaultInitAccount := initUserParams{
|
inputInitUserParams: defaultInitAccount,
|
||||||
Domain: publicDomain,
|
testingFunc: require.NotEqual,
|
||||||
UserId: "defaultUser",
|
expectedMSG: "account IDs shouldn't match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomainCategory: "",
|
||||||
testCase1 := test{
|
expectedDomain: publicDomain,
|
||||||
name: "New User With Public Domain",
|
expectedPrimaryDomainStatus: false,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: "pub-domain-user",
|
||||||
Domain: publicDomain,
|
expectedUsers: []string{"pub-domain-user"},
|
||||||
UserId: "pub-domain-user",
|
|
||||||
DomainCategory: PublicCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New User With Unknown Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: unknownDomain,
|
||||||
expectedDomainCategory: "",
|
UserId: "unknown-domain-user",
|
||||||
expectedDomain: publicDomain,
|
DomainCategory: UnknownCategory,
|
||||||
expectedPrimaryDomainStatus: false,
|
},
|
||||||
expectedCreatedBy: "pub-domain-user",
|
inputInitUserParams: initUnknown,
|
||||||
expectedUsers: []string{"pub-domain-user"},
|
testingFunc: require.NotEqual,
|
||||||
}
|
expectedMSG: "account IDs shouldn't match",
|
||||||
|
expectedUserRole: UserRoleOwner,
|
||||||
initUnknown := defaultInitAccount
|
expectedDomain: unknownDomain,
|
||||||
initUnknown.DomainCategory = UnknownCategory
|
expectedDomainCategory: "",
|
||||||
initUnknown.Domain = unknownDomain
|
expectedPrimaryDomainStatus: false,
|
||||||
|
expectedCreatedBy: "unknown-domain-user",
|
||||||
testCase2 := test{
|
expectedUsers: []string{"unknown-domain-user"},
|
||||||
name: "New User With Unknown Domain",
|
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
|
||||||
Domain: unknownDomain,
|
|
||||||
UserId: "unknown-domain-user",
|
|
||||||
DomainCategory: UnknownCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: initUnknown,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New User With Private Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: privateDomain,
|
||||||
expectedDomain: unknownDomain,
|
UserId: "pvt-domain-user",
|
||||||
expectedDomainCategory: "",
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: false,
|
},
|
||||||
expectedCreatedBy: "unknown-domain-user",
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedUsers: []string{"unknown-domain-user"},
|
testingFunc: require.NotEqual,
|
||||||
}
|
expectedMSG: "account IDs shouldn't match",
|
||||||
|
expectedUserRole: UserRoleOwner,
|
||||||
testCase3 := test{
|
expectedDomain: privateDomain,
|
||||||
name: "New User With Private Domain",
|
expectedDomainCategory: PrivateCategory,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedPrimaryDomainStatus: true,
|
||||||
Domain: privateDomain,
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
UserId: "pvt-domain-user",
|
expectedUsers: []string{"pvt-domain-user"},
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New Regular User With Existing Private Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: privateDomain,
|
||||||
expectedDomain: privateDomain,
|
UserId: "new-pvt-domain-user",
|
||||||
expectedDomainCategory: PrivateCategory,
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
},
|
||||||
expectedCreatedBy: "pvt-domain-user",
|
inputUpdateAttrs: true,
|
||||||
expectedUsers: []string{"pvt-domain-user"},
|
inputInitUserParams: privateInitAccount,
|
||||||
}
|
testingFunc: require.Equal,
|
||||||
|
expectedMSG: "account IDs should match",
|
||||||
privateInitAccount := defaultInitAccount
|
expectedUserRole: UserRoleUser,
|
||||||
privateInitAccount.Domain = privateDomain
|
expectedDomain: privateDomain,
|
||||||
privateInitAccount.DomainCategory = PrivateCategory
|
expectedDomainCategory: PrivateCategory,
|
||||||
|
expectedPrimaryDomainStatus: true,
|
||||||
testCase4 := test{
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
name: "New Regular User With Existing Private Domain",
|
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
|
||||||
Domain: privateDomain,
|
|
||||||
UserId: "new-pvt-domain-user",
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputUpdateAttrs: true,
|
{
|
||||||
inputInitUserParams: privateInitAccount,
|
name: "Existing User With Existing Reclassified Private Domain",
|
||||||
testingFunc: require.Equal,
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedMSG: "account IDs should match",
|
Domain: defaultInitAccount.Domain,
|
||||||
expectedUserRole: UserRoleUser,
|
UserId: defaultInitAccount.UserId,
|
||||||
expectedDomain: privateDomain,
|
DomainCategory: PrivateCategory,
|
||||||
expectedDomainCategory: PrivateCategory,
|
},
|
||||||
expectedPrimaryDomainStatus: true,
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
testingFunc: require.Equal,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
expectedMSG: "account IDs should match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
testCase5 := test{
|
expectedDomainCategory: PrivateCategory,
|
||||||
name: "Existing User With Existing Reclassified Private Domain",
|
expectedPrimaryDomainStatus: true,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
Domain: defaultInitAccount.Domain,
|
expectedUsers: []string{defaultInitAccount.UserId},
|
||||||
UserId: defaultInitAccount.UserId,
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.Equal,
|
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||||
expectedMSG: "account IDs should match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: defaultInitAccount.Domain,
|
||||||
expectedDomain: defaultInitAccount.Domain,
|
UserId: defaultInitAccount.UserId,
|
||||||
expectedDomainCategory: PrivateCategory,
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
},
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
inputUpdateClaimAccount: true,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId},
|
inputInitUserParams: defaultInitAccount,
|
||||||
}
|
testingFunc: require.Equal,
|
||||||
|
expectedMSG: "account IDs should match",
|
||||||
testCase6 := test{
|
expectedUserRole: UserRoleOwner,
|
||||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedDomainCategory: PrivateCategory,
|
||||||
Domain: defaultInitAccount.Domain,
|
expectedPrimaryDomainStatus: true,
|
||||||
UserId: defaultInitAccount.UserId,
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
DomainCategory: PrivateCategory,
|
expectedUsers: []string{defaultInitAccount.UserId},
|
||||||
},
|
},
|
||||||
inputUpdateClaimAccount: true,
|
{
|
||||||
inputInitUserParams: defaultInitAccount,
|
name: "User With Private Category And Empty Domain",
|
||||||
testingFunc: require.Equal,
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedMSG: "account IDs should match",
|
Domain: "",
|
||||||
expectedUserRole: UserRoleOwner,
|
UserId: "pvt-domain-user",
|
||||||
expectedDomain: defaultInitAccount.Domain,
|
DomainCategory: PrivateCategory,
|
||||||
expectedDomainCategory: PrivateCategory,
|
},
|
||||||
expectedPrimaryDomainStatus: true,
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
testingFunc: require.NotEqual,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId},
|
expectedMSG: "account IDs shouldn't match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomain: "",
|
||||||
testCase7 := test{
|
expectedDomainCategory: "",
|
||||||
name: "User With Private Category And Empty Domain",
|
expectedPrimaryDomainStatus: false,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
Domain: "",
|
expectedUsers: []string{"pvt-domain-user"},
|
||||||
UserId: "pvt-domain-user",
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
|
||||||
testingFunc: require.NotEqual,
|
|
||||||
expectedMSG: "account IDs shouldn't match",
|
|
||||||
expectedUserRole: UserRoleOwner,
|
|
||||||
expectedDomain: "",
|
|
||||||
expectedDomainCategory: "",
|
|
||||||
expectedPrimaryDomainStatus: false,
|
|
||||||
expectedCreatedBy: "pvt-domain-user",
|
|
||||||
expectedUsers: []string{"pvt-domain-user"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "get init account failed")
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
if testCase.inputUpdateAttrs {
|
if testCase.inputUpdateAttrs {
|
||||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||||
require.NoError(t, err, "update init user failed")
|
require.NoError(t, err, "update init user failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1963,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "Peers with inactivity expiration disabled, no expired peers",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Two peers expired",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
ID: "peer-2",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-3": {
|
||||||
|
ID: "peer-3",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-1": {},
|
||||||
|
"peer-2": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetInactivePeers()
|
||||||
|
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
|
||||||
|
t.Fatalf("expected to have peer %s expired", peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@@ -2032,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetPeersWithInactivity(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No account peers, no peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration disabled, no peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration enabled, return peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-1": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := account.GetPeersWithInactivity()
|
||||||
|
assert.Len(t, actual, len(testCase.expectedPeers))
|
||||||
|
if len(testCase.expectedPeers) > 0 {
|
||||||
|
for k := range testCase.expectedPeers {
|
||||||
|
contains := false
|
||||||
|
for _, peer := range actual {
|
||||||
|
if k == peer.ID {
|
||||||
|
contains = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, contains)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@@ -2193,6 +2340,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expiration time.Duration
|
||||||
|
expirationEnabled bool
|
||||||
|
expectedNextRun bool
|
||||||
|
expectedNextExpiration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNextExpiration := time.Minute
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No connected peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Connected peers with disabled expiration, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To be expired peer, return expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
LastSeen: time.Now().Add(-1 * time.Second),
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
LastLogin: time.Now().UTC(),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Minute,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: true,
|
||||||
|
expectedNextExpiration: expectedNextExpiration,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers added with setup keys, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration, ok := account.GetNextInactivePeerExpiration()
|
||||||
|
assert.Equal(t, testCase.expectedNextRun, ok)
|
||||||
|
if testCase.expectedNextRun {
|
||||||
|
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, expiration, testCase.expectedNextExpiration)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|||||||
@@ -139,6 +139,13 @@ const (
|
|||||||
PostureCheckUpdated Activity = 61
|
PostureCheckUpdated Activity = 61
|
||||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||||
PostureCheckDeleted Activity = 62
|
PostureCheckDeleted Activity = 62
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled Activity = 63
|
||||||
|
PeerInactivityExpirationDisabled Activity = 64
|
||||||
|
|
||||||
|
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||||
|
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||||
|
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||||
)
|
)
|
||||||
|
|
||||||
var activityMap = map[Activity]Code{
|
var activityMap = map[Activity]Code{
|
||||||
@@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
|
|||||||
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
||||||
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
||||||
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
|
||||||
|
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
|
||||||
|
|
||||||
|
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||||
|
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||||
|
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
account.Settings = &Settings{
|
account.Settings = &Settings{
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: false,
|
||||||
|
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Settings.Extra != nil {
|
if req.Settings.Extra != nil {
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ components:
|
|||||||
description: Period of time after which peer login expires (seconds).
|
description: Period of time after which peer login expires (seconds).
|
||||||
type: integer
|
type: integer
|
||||||
example: 43200
|
example: 43200
|
||||||
|
peer_inactivity_expiration_enabled:
|
||||||
|
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
peer_inactivity_expiration:
|
||||||
|
description: Period of time of inactivity after which peer session expires (seconds).
|
||||||
|
type: integer
|
||||||
|
example: 43200
|
||||||
regular_users_view_blocked:
|
regular_users_view_blocked:
|
||||||
description: Allows blocking regular users from viewing parts of the system.
|
description: Allows blocking regular users from viewing parts of the system.
|
||||||
type: boolean
|
type: boolean
|
||||||
@@ -81,6 +89,8 @@ components:
|
|||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
|
- peer_inactivity_expiration_enabled
|
||||||
|
- peer_inactivity_expiration
|
||||||
- regular_users_view_blocked
|
- regular_users_view_blocked
|
||||||
AccountExtraSettings:
|
AccountExtraSettings:
|
||||||
type: object
|
type: object
|
||||||
@@ -243,6 +253,9 @@ components:
|
|||||||
login_expiration_enabled:
|
login_expiration_enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: false
|
example: false
|
||||||
|
inactivity_expiration_enabled:
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
approval_required:
|
approval_required:
|
||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
@@ -251,6 +264,7 @@ components:
|
|||||||
- name
|
- name
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
- login_expiration_enabled
|
- login_expiration_enabled
|
||||||
|
- inactivity_expiration_enabled
|
||||||
Peer:
|
Peer:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PeerMinimum'
|
- $ref: '#/components/schemas/PeerMinimum'
|
||||||
@@ -327,6 +341,10 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
example: "2023-05-05T09:00:35.477782Z"
|
example: "2023-05-05T09:00:35.477782Z"
|
||||||
|
inactivity_expiration_enabled:
|
||||||
|
description: Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
approval_required:
|
approval_required:
|
||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
@@ -354,6 +372,7 @@ components:
|
|||||||
- last_seen
|
- last_seen
|
||||||
- login_expiration_enabled
|
- login_expiration_enabled
|
||||||
- login_expired
|
- login_expired
|
||||||
|
- inactivity_expiration_enabled
|
||||||
- os
|
- os
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
- user_id
|
- user_id
|
||||||
|
|||||||
@@ -220,6 +220,12 @@ type AccountSettings struct {
|
|||||||
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||||
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||||
|
|
||||||
|
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
||||||
|
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
||||||
|
|
||||||
|
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||||
|
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||||
|
|
||||||
@@ -538,6 +544,9 @@ type Peer struct {
|
|||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// Ip Peer's IP address
|
// Ip Peer's IP address
|
||||||
Ip string `json:"ip"`
|
Ip string `json:"ip"`
|
||||||
|
|
||||||
@@ -613,6 +622,9 @@ type PeerBatch struct {
|
|||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// Ip Peer's IP address
|
// Ip Peer's IP address
|
||||||
Ip string `json:"ip"`
|
Ip string `json:"ip"`
|
||||||
|
|
||||||
@@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
|
|||||||
// PeerRequest defines model for PeerRequest.
|
// PeerRequest defines model for PeerRequest.
|
||||||
type PeerRequest struct {
|
type PeerRequest struct {
|
||||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
Name string `json:"name"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
Name string `json:"name"`
|
||||||
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -14,7 +16,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeersHandler is a handler that returns peers of the account
|
// PeersHandler is a handler that returns peers of the account
|
||||||
@@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
SSHEnabled: req.SshEnabled,
|
SSHEnabled: req.SshEnabled,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||||
|
|
||||||
|
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.ApprovalRequired != nil {
|
if req.ApprovalRequired != nil {
|
||||||
@@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &api.Peer{
|
return &api.Peer{
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
Name: peer.Name,
|
Name: peer.Name,
|
||||||
Ip: peer.IP.String(),
|
Ip: peer.IP.String(),
|
||||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||||
Connected: peer.Status.Connected,
|
Connected: peer.Status.Connected,
|
||||||
LastSeen: peer.Status.LastSeen,
|
LastSeen: peer.Status.LastSeen,
|
||||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||||
KernelVersion: peer.Meta.KernelVersion,
|
KernelVersion: peer.Meta.KernelVersion,
|
||||||
GeonameId: int(peer.Location.GeoNameID),
|
GeonameId: int(peer.Location.GeoNameID),
|
||||||
Version: peer.Meta.WtVersion,
|
Version: peer.Meta.WtVersion,
|
||||||
Groups: groupsInfo,
|
Groups: groupsInfo,
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.SSHEnabled,
|
||||||
Hostname: peer.Meta.Hostname,
|
Hostname: peer.Meta.Hostname,
|
||||||
UserId: peer.UserID,
|
UserId: peer.UserID,
|
||||||
UiVersion: peer.Meta.UIVersion,
|
UiVersion: peer.Meta.UIVersion,
|
||||||
DnsLabel: fqdn(peer, dnsDomain),
|
DnsLabel: fqdn(peer, dnsDomain),
|
||||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||||
LastLogin: peer.LastLogin,
|
LastLogin: peer.LastLogin,
|
||||||
LoginExpired: peer.Status.LoginExpired,
|
LoginExpired: peer.Status.LoginExpired,
|
||||||
ApprovalRequired: !approved,
|
ApprovalRequired: !approved,
|
||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
|
|
||||||
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -110,6 +110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if expired {
|
||||||
|
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||||
|
// the expired one. Here we notify them that connection is now allowed again.
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||||
oldStatus := peer.Status.Copy()
|
oldStatus := peer.Status.Copy()
|
||||||
newStatus := oldStatus
|
newStatus := oldStatus
|
||||||
newStatus.LastSeen = time.Now().UTC()
|
newStatus.LastSeen = time.Now().UTC()
|
||||||
@@ -138,25 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
return oldStatus.LoginExpired, nil
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldStatus.LoginExpired {
|
|
||||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
|
||||||
// the expired one. Here we notify them that connection is now allowed again.
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
@@ -219,6 +234,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||||
|
|
||||||
|
if !peer.AddedWithSSOLogin() {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||||
|
|
||||||
|
event := activity.PeerInactivityExpirationEnabled
|
||||||
|
if !update.InactivityExpirationEnabled {
|
||||||
|
event = activity.PeerInactivityExpirationDisabled
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SaveAccount(ctx, account)
|
||||||
@@ -442,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
registrationTime := time.Now().UTC()
|
||||||
newPeer = &nbpeer.Peer{
|
newPeer = &nbpeer.Peer{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Key: peer.Key,
|
Key: peer.Key,
|
||||||
SetupKey: upperKey,
|
SetupKey: upperKey,
|
||||||
IP: freeIP,
|
IP: freeIP,
|
||||||
Meta: peer.Meta,
|
Meta: peer.Meta,
|
||||||
Name: peer.Meta.Hostname,
|
Name: peer.Meta.Hostname,
|
||||||
DNSLabel: freeLabel,
|
DNSLabel: freeLabel,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
SSHKey: peer.SSHKey,
|
SSHKey: peer.SSHKey,
|
||||||
LastLogin: registrationTime,
|
LastLogin: registrationTime,
|
||||||
CreatedAt: registrationTime,
|
CreatedAt: registrationTime,
|
||||||
LoginExpirationEnabled: addedByUser,
|
LoginExpirationEnabled: addedByUser,
|
||||||
Ephemeral: ephemeral,
|
Ephemeral: ephemeral,
|
||||||
Location: peer.Location,
|
Location: peer.Location,
|
||||||
|
InactivityExpirationEnabled: addedByUser,
|
||||||
}
|
}
|
||||||
opEvent.TargetID = newPeer.ID
|
opEvent.TargetID = newPeer.ID
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ type Peer struct {
|
|||||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||||
// Works with LastLogin
|
// Works with LastLogin
|
||||||
LoginExpirationEnabled bool
|
LoginExpirationEnabled bool
|
||||||
|
|
||||||
|
InactivityExpirationEnabled bool
|
||||||
// LastLogin the time when peer performed last login operation
|
// LastLogin the time when peer performed last login operation
|
||||||
LastLogin time.Time
|
LastLogin time.Time
|
||||||
// CreatedAt records the time the peer was created
|
// CreatedAt records the time the peer was created
|
||||||
@@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer {
|
|||||||
CreatedAt: p.CreatedAt,
|
CreatedAt: p.CreatedAt,
|
||||||
Ephemeral: p.Ephemeral,
|
Ephemeral: p.Ephemeral,
|
||||||
Location: p.Location,
|
Location: p.Location,
|
||||||
|
|
||||||
|
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
|
|||||||
p.Status = newStatus
|
p.Status = newStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionExpired indicates whether the peer's session has expired or not.
|
||||||
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
|
||||||
|
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
|
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||||
|
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||||
|
// Only peers added by interactive SSO login can be expired.
|
||||||
|
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||||
|
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
expiresAt := p.Status.LastSeen.Add(expiresIn)
|
||||||
|
now := time.Now()
|
||||||
|
timeLeft := expiresAt.Sub(now)
|
||||||
|
return timeLeft <= 0, timeLeft
|
||||||
|
}
|
||||||
|
|
||||||
// LoginExpired indicates whether the peer's login has expired or not.
|
// LoginExpired indicates whether the peer's login has expired or not.
|
||||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||||
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
|
|||||||
@@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeer_SessionExpired(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expirationEnabled bool
|
||||||
|
lastLogin time.Time
|
||||||
|
connected bool
|
||||||
|
expected bool
|
||||||
|
accountSettings *Settings
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
|
||||||
|
expirationEnabled: false,
|
||||||
|
connected: false,
|
||||||
|
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Hour,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Should Expire",
|
||||||
|
expirationEnabled: true,
|
||||||
|
connected: false,
|
||||||
|
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Should Not Expire",
|
||||||
|
expirationEnabled: true,
|
||||||
|
connected: true,
|
||||||
|
lastLogin: time.Now().UTC(),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range tt {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
peerStatus := &nbpeer.PeerStatus{
|
||||||
|
Connected: c.connected,
|
||||||
|
}
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
InactivityExpirationEnabled: c.expirationEnabled,
|
||||||
|
LastLogin: c.lastLogin,
|
||||||
|
Status: peerStatus,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
|
||||||
|
assert.Equal(t, expired, c.expected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||||
|
accountCopy := Account{
|
||||||
|
Domain: domain,
|
||||||
|
DomainCategory: category,
|
||||||
|
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||||
|
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||||
|
Select(fieldsToUpdate).
|
||||||
|
Where(idQueryCondition, accountID).
|
||||||
|
Updates(&accountCopy)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "account %s", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||||
var peerCopy nbpeer.Peer
|
var peerCopy nbpeer.Peer
|
||||||
peerCopy.Status = &peerStatus
|
peerCopy.Status = &peerStatus
|
||||||
@@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||||
|
var users []*User
|
||||||
|
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||||
@@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
|||||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
// we may need to reconsider changing the types.
|
||||||
|
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||||
|
if s.storeEngine == PostgresStoreEngine {
|
||||||
|
query = query.Order("json_array_length(peers::json) DESC")
|
||||||
|
} else {
|
||||||
|
query = query.Order("json_array_length(peers) DESC")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
|
|||||||
@@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, users, len(account.Users))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
t.Run("Should update attributes with public domain", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "example.com"
|
||||||
|
category := "public"
|
||||||
|
IsDomainPrimaryAccount := false
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.NoError(t, err)
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, domain, account.Domain)
|
||||||
|
require.Equal(t, category, account.DomainCategory)
|
||||||
|
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Should update attributes with private domain", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "test.com"
|
||||||
|
category := "private"
|
||||||
|
IsDomainPrimaryAccount := true
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.NoError(t, err)
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, domain, account.Domain)
|
||||||
|
require.Equal(t, category, account.DomainCategory)
|
||||||
|
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Should fail when account does not exist", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "test.com"
|
||||||
|
category := "private"
|
||||||
|
IsDomainPrimaryAccount := true
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetGroupByName(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "All", group.Name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -58,9 +58,11 @@ type Store interface {
|
|||||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
|
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
|
|||||||
@@ -16,8 +16,10 @@ const (
|
|||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
metric.Meter
|
metric.Meter
|
||||||
|
|
||||||
TransferBytesSent metric.Int64Counter
|
TransferBytesSent metric.Int64Counter
|
||||||
TransferBytesRecv metric.Int64Counter
|
TransferBytesRecv metric.Int64Counter
|
||||||
|
AuthenticationTime metric.Float64Histogram
|
||||||
|
PeerStoreTime metric.Float64Histogram
|
||||||
|
|
||||||
peers metric.Int64UpDownCounter
|
peers metric.Int64UpDownCounter
|
||||||
peerActivityChan chan string
|
peerActivityChan chan string
|
||||||
@@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Metrics{
|
m := &Metrics{
|
||||||
Meter: meter,
|
Meter: meter,
|
||||||
TransferBytesSent: bytesSent,
|
TransferBytesSent: bytesSent,
|
||||||
TransferBytesRecv: bytesRecv,
|
TransferBytesRecv: bytesRecv,
|
||||||
peers: peers,
|
AuthenticationTime: authTime,
|
||||||
|
PeerStoreTime: peerStoreTime,
|
||||||
|
peers: peers,
|
||||||
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
peerActivityChan: make(chan string, 10),
|
peerActivityChan: make(chan string, 10),
|
||||||
@@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
|
|||||||
m.peerLastActive[id] = time.Time{}
|
m.peerLastActive[id] = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordAuthenticationTime measures the time taken for peer authentication
|
||||||
|
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
|
||||||
|
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordPeerStoreTime measures the time to store the peer in map
|
||||||
|
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||||
|
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||||
func (m *Metrics) PeerDisconnected(id string) {
|
func (m *Metrics) PeerDisconnected(id string) {
|
||||||
m.peers.Add(m.ctx, -1)
|
m.peers.Add(m.ctx, -1)
|
||||||
@@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStandardBucketBoundaries() []float64 {
|
||||||
|
return []float64{
|
||||||
|
0.1,
|
||||||
|
0.5,
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
50,
|
||||||
|
100,
|
||||||
|
500,
|
||||||
|
1000,
|
||||||
|
5000,
|
||||||
|
10000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
153
relay/server/handshake.go
Normal file
153
relay/server/handshake.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
//nolint:staticcheck
|
||||||
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
|
//nolint:staticcheck
|
||||||
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// preparedMsg contains the marshalled success response messages
|
||||||
|
type preparedMsg struct {
|
||||||
|
responseHelloMsg []byte
|
||||||
|
responseAuthMsg []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
|
||||||
|
rhm, err := marshalResponseHelloMsg(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ram, err := messages.MarshalAuthResponse(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &preparedMsg{
|
||||||
|
responseHelloMsg: rhm,
|
||||||
|
responseAuthMsg: ram,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||||
|
addr := &address.Address{URL: instanceURL}
|
||||||
|
addrData, err := addr.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal response address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
|
||||||
|
}
|
||||||
|
return responseMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshake struct {
|
||||||
|
conn net.Conn
|
||||||
|
validator auth.Validator
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
|
handshakeMethodAuth bool
|
||||||
|
peerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||||
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
|
n, err := h.conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = messages.ValidateVersion(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
bytePeerID []byte
|
||||||
|
peerID string
|
||||||
|
)
|
||||||
|
switch msgType {
|
||||||
|
//nolint:staticcheck
|
||||||
|
case messages.MsgTypeHello:
|
||||||
|
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
case messages.MsgTypeAuth:
|
||||||
|
h.handshakeMethodAuth = true
|
||||||
|
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.peerID = peerID
|
||||||
|
return bytePeerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeResponse() error {
|
||||||
|
var responseMsg []byte
|
||||||
|
if h.handshakeMethodAuth {
|
||||||
|
responseMsg = h.preparedMsg.responseAuthMsg
|
||||||
|
} else {
|
||||||
|
responseMsg = h.preparedMsg.responseHelloMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||||
|
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
//nolint:staticcheck
|
||||||
|
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||||
|
|
||||||
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
|
||||||
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
||||||
@@ -7,16 +7,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
|
||||||
//nolint:staticcheck
|
|
||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,6 +25,7 @@ type Relay struct {
|
|||||||
|
|
||||||
store *Store
|
store *Store
|
||||||
instanceURL string
|
instanceURL string
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
closeMu sync.RWMutex
|
closeMu sync.RWMutex
|
||||||
@@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
|||||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
metricsCancel()
|
||||||
|
return nil, fmt.Errorf("prepare message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
|||||||
|
|
||||||
// Accept start to handle a new peer connection
|
// Accept start to handle a new peer connection
|
||||||
func (r *Relay) Accept(conn net.Conn) {
|
func (r *Relay) Accept(conn net.Conn) {
|
||||||
|
acceptTime := time.Now()
|
||||||
r.closeMu.RLock()
|
r.closeMu.RLock()
|
||||||
defer r.closeMu.RUnlock()
|
defer r.closeMu.RUnlock()
|
||||||
if r.closed {
|
if r.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID, err := r.handshake(conn)
|
h := handshake{
|
||||||
|
conn: conn,
|
||||||
|
validator: r.validator,
|
||||||
|
preparedMsg: r.preparedMsg,
|
||||||
|
}
|
||||||
|
peerID, err := h.handshakeReceive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to handshake: %s", err)
|
log.Errorf("failed to handshake: %s", err)
|
||||||
cErr := conn.Close()
|
if cErr := conn.Close(); cErr != nil {
|
||||||
if cErr != nil {
|
|
||||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
|
|
||||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
|
storeTime := time.Now()
|
||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
|
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||||
r.metrics.PeerConnected(peer.String())
|
r.metrics.PeerConnected(peer.String())
|
||||||
go func() {
|
go func() {
|
||||||
peer.Work()
|
peer.Work()
|
||||||
@@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
peer.log.Debugf("relay connection closed")
|
peer.log.Debugf("relay connection closed")
|
||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if err := h.handshakeResponse(); err != nil {
|
||||||
|
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||||
|
peer.Close()
|
||||||
|
}
|
||||||
|
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown closes the relay server
|
// Shutdown closes the relay server
|
||||||
@@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
|||||||
func (r *Relay) InstanceURL() string {
|
func (r *Relay) InstanceURL() string {
|
||||||
return r.instanceURL
|
return r.instanceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = messages.ValidateVersion(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
responseMsg []byte
|
|
||||||
peerID []byte
|
|
||||||
)
|
|
||||||
switch msgType {
|
|
||||||
//nolint:staticcheck
|
|
||||||
case messages.MsgTypeHello:
|
|
||||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
case messages.MsgTypeAuth:
|
|
||||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.Write(responseMsg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
|
||||||
//nolint:staticcheck
|
|
||||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &address.Address{URL: r.instanceURL}
|
|
||||||
addrData, err := addr.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
|
||||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
|
|
||||||
if err := r.validator.Validate(authPayload); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user