mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 15:16:38 +00:00
Compare commits
5 Commits
feature/ap
...
v0.9.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93d20e370b | ||
|
|
878ca6db22 | ||
|
|
2033650908 | ||
|
|
34c1c7d901 | ||
|
|
051fd3a4d7 |
@@ -292,9 +292,6 @@ func isProviderConfigValid(config ProviderConfig) error {
|
|||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||||
}
|
}
|
||||||
if config.ClientSecret == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client Secret")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
if config.TokenEndpoint == "" {
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,14 +9,16 @@ import (
|
|||||||
import "github.com/google/nftables"
|
import "github.com/google/nftables"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
||||||
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
|
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
|
||||||
ipv6Nat = "netbird-rt-ipv6-nat"
|
ipv6Nat = "netbird-rt-ipv6-nat"
|
||||||
ipv4Nat = "netbird-rt-ipv4-nat"
|
ipv4Nat = "netbird-rt-ipv4-nat"
|
||||||
natFormat = "netbird-nat-%s"
|
natFormat = "netbird-nat-%s"
|
||||||
forwardingFormat = "netbird-fwd-%s"
|
forwardingFormat = "netbird-fwd-%s"
|
||||||
ipv6 = "ipv6"
|
inNatFormat = "netbird-nat-in-%s"
|
||||||
ipv4 = "ipv4"
|
inForwardingFormat = "netbird-fwd-in-%s"
|
||||||
|
ipv6 = "ipv6"
|
||||||
|
ipv4 = "ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
func genKey(format string, input string) string {
|
func genKey(format string, input string) string {
|
||||||
@@ -53,3 +55,13 @@ func NewFirewall(parentCTX context.Context) firewallManager {
|
|||||||
|
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getInPair(pair routerPair) routerPair {
|
||||||
|
return routerPair{
|
||||||
|
ID: pair.ID,
|
||||||
|
// invert source/destination
|
||||||
|
source: pair.destination,
|
||||||
|
destination: pair.source,
|
||||||
|
masquerade: pair.masquerade,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -311,7 +311,37 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
defer i.mux.Unlock()
|
defer i.mux.Unlock()
|
||||||
|
|
||||||
|
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts an iptable rule
|
||||||
|
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
ipVersion := ipv4
|
ipVersion := ipv4
|
||||||
iptablesClient := i.ipv4Client
|
iptablesClient := i.ipv4Client
|
||||||
@@ -320,43 +350,22 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
ipVersion = ipv6
|
ipVersion = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
ruleKey := genKey(keyFormat, pair.ID)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
|
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
|
||||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||||
if found {
|
if found {
|
||||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
delete(i.rules[ipVersion], forwardRuleKey)
|
delete(i.rules[ipVersion], ruleKey)
|
||||||
}
|
}
|
||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
err = iptablesClient.Insert(table, chain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ipVersion][forwardRuleKey] = forwardRule
|
i.rules[ipVersion][ruleKey] = rule
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, pair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
|
|
||||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], natRuleKey)
|
|
||||||
}
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.rules[ipVersion][natRuleKey] = natRule
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -366,7 +375,37 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
defer i.mux.Unlock()
|
defer i.mux.Unlock()
|
||||||
|
|
||||||
|
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutingRule removes an iptables rule
|
||||||
|
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
ipVersion := ipv4
|
ipVersion := ipv4
|
||||||
iptablesClient := i.ipv4Client
|
iptablesClient := i.ipv4Client
|
||||||
@@ -375,29 +414,23 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
ipVersion = ipv6
|
ipVersion = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
ruleKey := genKey(keyFormat, pair.ID)
|
||||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||||
if found {
|
if found {
|
||||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(i.rules[ipVersion], forwardRuleKey)
|
delete(i.rules[ipVersion], ruleKey)
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, pair.ID)
|
|
||||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], natRuleKey)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getIptablesRuleType(table string) string {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if table == iptablesNatTable {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
return ruleType
|
||||||
|
}
|
||||||
|
|||||||
@@ -159,6 +159,17 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||||
|
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
|
require.True(t, exists, "income forwarding rule should exist")
|
||||||
|
|
||||||
|
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
||||||
|
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||||
|
|
||||||
@@ -172,7 +183,23 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
require.False(t, exists, "nat rule should not be created")
|
require.False(t, exists, "nat rule should not be created")
|
||||||
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
||||||
require.False(t, foundNat, "nat rule should exist in the map")
|
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||||
|
}
|
||||||
|
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
if testCase.inputPair.masquerade {
|
||||||
|
require.True(t, exists, "income nat rule should be created")
|
||||||
|
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||||
|
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "nat rule should not be created")
|
||||||
|
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -213,12 +240,24 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
delete(manager.rules, ipv4)
|
||||||
delete(manager.rules, ipv6)
|
delete(manager.rules, ipv6)
|
||||||
|
|
||||||
@@ -235,12 +274,26 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
|
require.False(t, exists, "income forwarding rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
||||||
|
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
require.False(t, exists, "nat rule should not exist")
|
require.False(t, exists, "nat rule should not exist")
|
||||||
|
|
||||||
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
require.False(t, found, "nat rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
require.False(t, exists, "income nat rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.False(t, found, "income nat rule should exist in the manager map")
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
)
|
)
|
||||||
import "github.com/google/nftables"
|
import "github.com/google/nftables"
|
||||||
|
|
||||||
//
|
|
||||||
const (
|
const (
|
||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
@@ -248,53 +247,77 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
defer n.mux.Unlock()
|
defer n.mux.Unlock()
|
||||||
|
|
||||||
|
err := n.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.masquerade {
|
||||||
|
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = n.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||||
|
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
||||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
||||||
|
|
||||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
var expression []expr.Any
|
||||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
if isNat {
|
||||||
if prefix.Addr().Unmap().Is4() {
|
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(fwdKey),
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(fwdKey),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair.masquerade {
|
ruleKey := genKey(format, pair.ID)
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
||||||
natKey := genKey(natFormat, pair.ID)
|
|
||||||
|
|
||||||
if prefix.Addr().Unmap().Is4() {
|
_, exists := n.rules[ruleKey]
|
||||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
if exists {
|
||||||
Table: n.tableIPv4,
|
err := n.removeRoutingRule(format, pair)
|
||||||
Chain: n.chains[ipv4][nftablesRoutingNatChain],
|
if err != nil {
|
||||||
Exprs: natExp,
|
return err
|
||||||
UserData: []byte(natKey),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][nftablesRoutingNatChain],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natKey),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := n.conn.Flush()
|
if prefix.Addr().Unmap().Is4() {
|
||||||
if err != nil {
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
Table: n.tableIPv4,
|
||||||
|
Chain: n.chains[ipv4][chain],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: n.tableIPv6,
|
||||||
|
Chain: n.chains[ipv6][chain],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -309,26 +332,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
err = n.removeRoutingRule(forwardingFormat, pair)
|
||||||
natKey := genKey(natFormat, pair.ID)
|
if err != nil {
|
||||||
fwdRule, found := n.rules[fwdKey]
|
return err
|
||||||
if found {
|
|
||||||
err = n.conn.DelRule(fwdRule)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
|
|
||||||
delete(n.rules, fwdKey)
|
|
||||||
}
|
}
|
||||||
natRule, found := n.rules[natKey]
|
|
||||||
if found {
|
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
|
||||||
err = n.conn.DelRule(natRule)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
log.Debugf("nftables: removing nat rule for %s", pair.destination)
|
|
||||||
delete(n.rules, natKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = n.removeRoutingRule(natFormat, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
@@ -337,6 +360,29 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||||
|
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
|
||||||
|
ruleKey := genKey(format, pair.ID)
|
||||||
|
|
||||||
|
rule, found := n.rules[ruleKey]
|
||||||
|
if found {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := n.conn.DelRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
|
||||||
|
|
||||||
|
delete(n.rules, ruleKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -189,6 +189,45 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
||||||
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
||||||
|
testingExpression = append(sourceExp, destExp...)
|
||||||
|
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
|
||||||
|
found = 0
|
||||||
|
for _, registeredChains := range manager.chains {
|
||||||
|
for _, chain := range registeredChains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
|
||||||
|
if testCase.inputPair.masquerade {
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
found := 0
|
||||||
|
for _, registeredChains := range manager.chains {
|
||||||
|
for _, chain := range registeredChains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -241,6 +280,28 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
UserData: []byte(natRuleKey),
|
UserData: []byte(natRuleKey),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
||||||
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
||||||
|
Exprs: forwardExp,
|
||||||
|
UserData: []byte(inForwardRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
|
||||||
|
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
||||||
|
Exprs: natExp,
|
||||||
|
UserData: []byte(inNatRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
err = nftablesTestingClient.Flush()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
@@ -259,8 +320,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
|
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,7 +109,9 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := c.connectToStream(*serverPubKey)
|
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||||
|
defer cancelStream()
|
||||||
|
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to open Management Service stream: %s", err)
|
log.Debugf("failed to open Management Service stream: %s", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||||
@@ -145,7 +147,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||||
req := &proto.SyncRequest{}
|
req := &proto.SyncRequest{}
|
||||||
|
|
||||||
myPrivateKey := c.key
|
myPrivateKey := c.key
|
||||||
@@ -156,9 +158,12 @@ func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.Management
|
|||||||
log.Errorf("failed encrypting message: %s", err)
|
log.Errorf("failed encrypting message: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
||||||
return c.realClient.Sync(c.ctx, syncReq)
|
sync, err := c.realClient.Sync(ctx, syncReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return sync, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -480,7 +481,7 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
|
|||||||
) (*Account, error) {
|
) (*Account, error) {
|
||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
|
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
accountFromID, err := am.GetAccountById(claims.AccountId)
|
accountFromID, err := am.GetAccountById(claims.AccountId)
|
||||||
@@ -520,6 +521,11 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isDomainValid(domain string) bool {
|
||||||
|
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||||
|
return re.Match([]byte(domain))
|
||||||
|
}
|
||||||
|
|
||||||
// AccountExists checks whether account exists (returns true) or not (returns false)
|
// AccountExists checks whether account exists (returns true) or not (returns false)
|
||||||
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
|
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
expectedMSG string
|
expectedMSG string
|
||||||
expectedUserRole UserRole
|
expectedUserRole UserRole
|
||||||
expectedDomainCategory string
|
expectedDomainCategory string
|
||||||
|
expectedDomain string
|
||||||
expectedPrimaryDomainStatus bool
|
expectedPrimaryDomainStatus bool
|
||||||
expectedCreatedBy string
|
expectedCreatedBy string
|
||||||
expectedUsers []string
|
expectedUsers []string
|
||||||
@@ -168,6 +169,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
expectedMSG: "account IDs shouldn't match",
|
expectedMSG: "account IDs shouldn't match",
|
||||||
expectedUserRole: UserRoleAdmin,
|
expectedUserRole: UserRoleAdmin,
|
||||||
expectedDomainCategory: "",
|
expectedDomainCategory: "",
|
||||||
|
expectedDomain: publicDomain,
|
||||||
expectedPrimaryDomainStatus: false,
|
expectedPrimaryDomainStatus: false,
|
||||||
expectedCreatedBy: "pub-domain-user",
|
expectedCreatedBy: "pub-domain-user",
|
||||||
expectedUsers: []string{"pub-domain-user"},
|
expectedUsers: []string{"pub-domain-user"},
|
||||||
@@ -188,6 +190,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
testingFunc: require.NotEqual,
|
testingFunc: require.NotEqual,
|
||||||
expectedMSG: "account IDs shouldn't match",
|
expectedMSG: "account IDs shouldn't match",
|
||||||
expectedUserRole: UserRoleAdmin,
|
expectedUserRole: UserRoleAdmin,
|
||||||
|
expectedDomain: unknownDomain,
|
||||||
expectedDomainCategory: "",
|
expectedDomainCategory: "",
|
||||||
expectedPrimaryDomainStatus: false,
|
expectedPrimaryDomainStatus: false,
|
||||||
expectedCreatedBy: "unknown-domain-user",
|
expectedCreatedBy: "unknown-domain-user",
|
||||||
@@ -205,6 +208,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
testingFunc: require.NotEqual,
|
testingFunc: require.NotEqual,
|
||||||
expectedMSG: "account IDs shouldn't match",
|
expectedMSG: "account IDs shouldn't match",
|
||||||
expectedUserRole: UserRoleAdmin,
|
expectedUserRole: UserRoleAdmin,
|
||||||
|
expectedDomain: privateDomain,
|
||||||
expectedDomainCategory: PrivateCategory,
|
expectedDomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
expectedPrimaryDomainStatus: true,
|
||||||
expectedCreatedBy: "pvt-domain-user",
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
@@ -227,6 +231,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
testingFunc: require.Equal,
|
testingFunc: require.Equal,
|
||||||
expectedMSG: "account IDs should match",
|
expectedMSG: "account IDs should match",
|
||||||
expectedUserRole: UserRoleUser,
|
expectedUserRole: UserRoleUser,
|
||||||
|
expectedDomain: privateDomain,
|
||||||
expectedDomainCategory: PrivateCategory,
|
expectedDomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
expectedPrimaryDomainStatus: true,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
@@ -244,6 +249,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
testingFunc: require.Equal,
|
testingFunc: require.Equal,
|
||||||
expectedMSG: "account IDs should match",
|
expectedMSG: "account IDs should match",
|
||||||
expectedUserRole: UserRoleAdmin,
|
expectedUserRole: UserRoleAdmin,
|
||||||
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
expectedDomainCategory: PrivateCategory,
|
expectedDomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
expectedPrimaryDomainStatus: true,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
@@ -262,12 +268,32 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
testingFunc: require.Equal,
|
testingFunc: require.Equal,
|
||||||
expectedMSG: "account IDs should match",
|
expectedMSG: "account IDs should match",
|
||||||
expectedUserRole: UserRoleAdmin,
|
expectedUserRole: UserRoleAdmin,
|
||||||
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
expectedDomainCategory: PrivateCategory,
|
expectedDomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
expectedPrimaryDomainStatus: true,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId},
|
expectedUsers: []string{defaultInitAccount.UserId},
|
||||||
}
|
}
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
|
||||||
|
testCase7 := test{
|
||||||
|
name: "User With Private Category And Empty Domain",
|
||||||
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
|
Domain: "",
|
||||||
|
UserId: "pvt-domain-user",
|
||||||
|
DomainCategory: PrivateCategory,
|
||||||
|
},
|
||||||
|
inputInitUserParams: defaultInitAccount,
|
||||||
|
testingFunc: require.NotEqual,
|
||||||
|
expectedMSG: "account IDs shouldn't match",
|
||||||
|
expectedUserRole: UserRoleAdmin,
|
||||||
|
expectedDomain: "",
|
||||||
|
expectedDomainCategory: "",
|
||||||
|
expectedPrimaryDomainStatus: false,
|
||||||
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
|
expectedUsers: []string{"pvt-domain-user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@@ -294,6 +320,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
|
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
|
||||||
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
|
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
|
||||||
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
|
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
|
||||||
|
require.EqualValues(t, testCase.expectedDomain, account.Domain, "expected account domain should match")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,20 +96,18 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/GroupMinimum'
|
$ref: '#/components/schemas/GroupMinimum'
|
||||||
activated_by:
|
|
||||||
description: Provides information of who activated the Peer. User or Setup Key
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
value:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- value
|
|
||||||
ssh_enabled:
|
ssh_enabled:
|
||||||
description: Indicates whether SSH server is enabled on this peer
|
description: Indicates whether SSH server is enabled on this peer
|
||||||
type: boolean
|
type: boolean
|
||||||
|
user_id:
|
||||||
|
description: User ID of the user that enrolled this peer
|
||||||
|
type: string
|
||||||
|
hostname:
|
||||||
|
description: Hostname of the machine
|
||||||
|
type: string
|
||||||
|
ui_version:
|
||||||
|
description: Peer's desktop UI version
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- ip
|
- ip
|
||||||
- connected
|
- connected
|
||||||
@@ -117,8 +115,8 @@ components:
|
|||||||
- os
|
- os
|
||||||
- version
|
- version
|
||||||
- groups
|
- groups
|
||||||
- activated_by
|
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
|
- hostname
|
||||||
SetupKey:
|
SetupKey:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|||||||
@@ -125,18 +125,15 @@ type PatchMinimumOp string
|
|||||||
|
|
||||||
// Peer defines model for Peer.
|
// Peer defines model for Peer.
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// Provides information of who activated the Peer. User or Setup Key
|
|
||||||
ActivatedBy struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
} `json:"activated_by"`
|
|
||||||
|
|
||||||
// Peer to Management connection status
|
// Peer to Management connection status
|
||||||
Connected bool `json:"connected"`
|
Connected bool `json:"connected"`
|
||||||
|
|
||||||
// Groups that the peer belongs to
|
// Groups that the peer belongs to
|
||||||
Groups []GroupMinimum `json:"groups"`
|
Groups []GroupMinimum `json:"groups"`
|
||||||
|
|
||||||
|
// Hostname of the machine
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
|
||||||
// Peer ID
|
// Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
@@ -155,6 +152,12 @@ type Peer struct {
|
|||||||
// Indicates whether SSH server is enabled on this peer
|
// Indicates whether SSH server is enabled on this peer
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
|
||||||
|
// Peer's desktop UI version
|
||||||
|
UiVersion *string `json:"ui_version,omitempty"`
|
||||||
|
|
||||||
|
// User ID of the user that enrolled this peer
|
||||||
|
UserId *string `json:"user_id,omitempty"`
|
||||||
|
|
||||||
// Peer's daemon or cli version
|
// Peer's daemon or cli version
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
//Peers is a handler that returns peers of the account
|
// Peers is a handler that returns peers of the account
|
||||||
type Peers struct {
|
type Peers struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
authAudience string
|
authAudience string
|
||||||
@@ -144,5 +144,8 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
|||||||
Version: peer.Meta.WtVersion,
|
Version: peer.Meta.WtVersion,
|
||||||
Groups: groupsInfo,
|
Groups: groupsInfo,
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.SSHEnabled,
|
||||||
|
Hostname: peer.Meta.Hostname,
|
||||||
|
UserId: &peer.UserID,
|
||||||
|
UiVersion: &peer.Meta.UIVersion,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//defaultBackoff is a basic backoff mechanism for general issues
|
// defaultBackoff is a basic backoff mechanism for general issues
|
||||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 800 * time.Millisecond,
|
InitialInterval: 800 * time.Millisecond,
|
||||||
@@ -121,9 +121,11 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
|||||||
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to Signal stream identifying ourselves with a public Wireguard key
|
// connect to Signal stream identifying ourselves with a public WireGuard key
|
||||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||||
stream, err := c.connect(c.key.PublicKey().String())
|
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||||
|
defer cancelStream()
|
||||||
|
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -180,15 +182,13 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
|||||||
return c.connectedCh
|
return c.connectedCh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||||
c.stream = nil
|
c.stream = nil
|
||||||
|
|
||||||
// add key fingerprint to the request header to be identified on the server side
|
// add key fingerprint to the request header to be identified on the server side
|
||||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||||
ctx := metadata.NewOutgoingContext(c.ctx, md)
|
metaCtx := metadata.NewOutgoingContext(ctx, md)
|
||||||
|
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
|
||||||
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
|
||||||
|
|
||||||
c.stream = stream
|
c.stream = stream
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
Reference in New Issue
Block a user