Compare commits

...

5 Commits

Author SHA1 Message Date
Maycon Santos
93d20e370b Add incoming routing rules (#486)
add an income firewall rule for each routing pair
the pair for the income rule has inverted
source and destination
2022-09-30 14:39:15 +05:00
Maycon Santos
878ca6db22 Check if domain from claim is valid (#485)
If domain is invalid we call GetAccountByUserOrAccountId
2022-09-29 13:51:18 +05:00
braginini
2033650908 Remove IdP client secret validation 2022-09-26 18:58:14 +02:00
Misha Bragin
34c1c7d901 Add hostname, userID, ui version to the HTTP API peer response (#479) 2022-09-26 18:02:45 +02:00
Misha Bragin
051fd3a4d7 Fix Management and Signal gRPC client stream leak (#482) 2022-09-26 18:02:20 +02:00
13 changed files with 397 additions and 151 deletions

View File

@@ -292,9 +292,6 @@ func isProviderConfigValid(config ProviderConfig) error {
if config.ClientID == "" { if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID") return fmt.Errorf(errorMSGFormat, "Client ID")
} }
if config.ClientSecret == "" {
return fmt.Errorf(errorMSGFormat, "Client Secret")
}
if config.TokenEndpoint == "" { if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint") return fmt.Errorf(errorMSGFormat, "Token Endpoint")
} }

View File

@@ -15,6 +15,8 @@ const (
ipv4Nat = "netbird-rt-ipv4-nat" ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s" natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s" forwardingFormat = "netbird-fwd-%s"
inNatFormat = "netbird-nat-in-%s"
inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6" ipv6 = "ipv6"
ipv4 = "ipv4" ipv4 = "ipv4"
) )
@@ -53,3 +55,13 @@ func NewFirewall(parentCTX context.Context) firewallManager {
return manager return manager
} }
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -311,7 +311,37 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock() i.mux.Lock()
defer i.mux.Unlock() defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error var err error
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4 ipVersion := ipv4
iptablesClient := i.ipv4Client iptablesClient := i.ipv4Client
@@ -320,43 +350,22 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
ipVersion = ipv6 ipVersion = ipv6
} }
forwardRuleKey := genKey(forwardingFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][forwardRuleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
delete(i.rules[ipVersion], forwardRuleKey) delete(i.rules[ipVersion], ruleKey)
} }
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) err = iptablesClient.Insert(table, chain, 1, rule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
i.rules[ipVersion][forwardRuleKey] = forwardRule i.rules[ipVersion][ruleKey] = rule
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], natRuleKey)
}
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
}
i.rules[ipVersion][natRuleKey] = natRule
return nil return nil
} }
@@ -366,7 +375,37 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock() i.mux.Lock()
defer i.mux.Unlock() defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error var err error
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4 ipVersion := ipv4
iptablesClient := i.ipv4Client iptablesClient := i.ipv4Client
@@ -375,29 +414,23 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
ipVersion = ipv6 ipVersion = ipv6
} }
forwardRuleKey := genKey(forwardingFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][forwardRuleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
} }
delete(i.rules[ipVersion], forwardRuleKey) delete(i.rules[ipVersion], ruleKey)
if !pair.masquerade {
return nil return nil
} }
natRuleKey := genKey(natFormat, pair.ID) func getIptablesRuleType(table string) string {
existingRule, found = i.rules[ipVersion][natRuleKey] ruleType := "forwarding"
if found { if table == iptablesNatTable {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...) ruleType = "nat"
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
} }
} return ruleType
delete(i.rules[ipVersion], natRuleKey)
return nil
} }

View File

@@ -159,6 +159,17 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.True(t, found, "forwarding rule should exist in the manager map") require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID) natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
@@ -172,7 +183,23 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey] _, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should exist in the map") require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
} }
}) })
} }
@@ -213,12 +240,24 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID) natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4) delete(manager.rules, ipv4)
delete(manager.rules, ipv6) delete(manager.rules, ipv6)
@@ -235,12 +274,26 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found := manager.rules[testCase.ipVersion][forwardRuleKey] _, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map") require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist") require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey] _, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map") require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
}) })
} }

View File

@@ -12,7 +12,6 @@ import (
) )
import "github.com/google/nftables" import "github.com/google/nftables"
//
const ( const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd" nftablesRoutingForwardingChain = "netbird-rt-fwd"
@@ -248,53 +247,77 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock() n.mux.Lock()
defer n.mux.Unlock() defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source) sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination) destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) var expression []expr.Any
fwdKey := genKey(forwardingFormat, pair.ID) if isNat {
if prefix.Addr().Unmap().Is4() { expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} else { } else {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} }
if pair.masquerade { ruleKey := genKey(format, pair.ID)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natKey := genKey(natFormat, pair.ID)
if prefix.Addr().Unmap().Is4() { _, exists := n.rules[ruleKey]
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ if exists {
Table: n.tableIPv4, err := n.removeRoutingRule(format, pair)
Chain: n.chains[ipv4][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
} else {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
}
}
err := n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) return err
}
}
if prefix.Addr().Unmap().Is4() {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} }
return nil return nil
} }
@@ -309,26 +332,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err return err
} }
fwdKey := genKey(forwardingFormat, pair.ID) err = n.removeRoutingRule(forwardingFormat, pair)
natKey := genKey(natFormat, pair.ID)
fwdRule, found := n.rules[fwdKey]
if found {
err = n.conn.DelRule(fwdRule)
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err) return err
} }
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
delete(n.rules, fwdKey) err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
}
natRule, found := n.rules[natKey]
if found {
err = n.conn.DelRule(natRule)
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err) return err
} }
log.Debugf("nftables: removing nat rule for %s", pair.destination)
delete(n.rules, natKey) err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
} }
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -337,6 +360,29 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return nil return nil
} }
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -189,6 +189,45 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...)
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
}) })
} }
} }
@@ -241,6 +280,28 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey), UserData: []byte(natRuleKey),
}) })
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush() err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
@@ -259,8 +320,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist") require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
} }
} }
} }

View File

@@ -109,7 +109,9 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return err return err
} }
stream, err := c.connectToStream(*serverPubKey) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -145,7 +147,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil return nil
} }
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{} req := &proto.SyncRequest{}
myPrivateKey := c.key myPrivateKey := c.key
@@ -156,9 +158,12 @@ func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.Management
log.Errorf("failed encrypting message: %s", err) log.Errorf("failed encrypting message: %s", err)
return nil, err return nil, err
} }
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq} syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
return c.realClient.Sync(c.ctx, syncReq) sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
} }
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {

View File

@@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"math/rand" "math/rand"
"reflect" "reflect"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -480,7 +481,7 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
) (*Account, error) { ) (*Account, error) {
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId) accountFromID, err := am.GetAccountById(claims.AccountId)
@@ -520,6 +521,11 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
} }
} }
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
}
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock() am.mux.Lock()

View File

@@ -140,6 +140,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG string expectedMSG string
expectedUserRole UserRole expectedUserRole UserRole
expectedDomainCategory string expectedDomainCategory string
expectedDomain string
expectedPrimaryDomainStatus bool expectedPrimaryDomainStatus bool
expectedCreatedBy string expectedCreatedBy string
expectedUsers []string expectedUsers []string
@@ -168,6 +169,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomainCategory: "", expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false, expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user", expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"}, expectedUsers: []string{"pub-domain-user"},
@@ -188,6 +190,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual, testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: unknownDomain,
expectedDomainCategory: "", expectedDomainCategory: "",
expectedPrimaryDomainStatus: false, expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user", expectedCreatedBy: "unknown-domain-user",
@@ -205,6 +208,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual, testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user", expectedCreatedBy: "pvt-domain-user",
@@ -227,6 +231,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser, expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
@@ -244,6 +249,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
@@ -262,12 +268,32 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId}, expectedUsers: []string{defaultInitAccount.UserId},
} }
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
testCase7 := test{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
@@ -294,6 +320,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match") require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match") require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match") require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
require.EqualValues(t, testCase.expectedDomain, account.Domain, "expected account domain should match")
}) })
} }
} }

View File

@@ -96,20 +96,18 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
activated_by:
description: Provides information of who activated the Peer. User or Setup Key
type: object
properties:
type:
type: string
value:
type: string
required:
- type
- value
ssh_enabled: ssh_enabled:
description: Indicates whether SSH server is enabled on this peer description: Indicates whether SSH server is enabled on this peer
type: boolean type: boolean
user_id:
description: User ID of the user that enrolled this peer
type: string
hostname:
description: Hostname of the machine
type: string
ui_version:
description: Peer's desktop UI version
type: string
required: required:
- ip - ip
- connected - connected
@@ -117,8 +115,8 @@ components:
- os - os
- version - version
- groups - groups
- activated_by
- ssh_enabled - ssh_enabled
- hostname
SetupKey: SetupKey:
type: object type: object
properties: properties:

View File

@@ -125,18 +125,15 @@ type PatchMinimumOp string
// Peer defines model for Peer. // Peer defines model for Peer.
type Peer struct { type Peer struct {
// Provides information of who activated the Peer. User or Setup Key
ActivatedBy struct {
Type string `json:"type"`
Value string `json:"value"`
} `json:"activated_by"`
// Peer to Management connection status // Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
// Groups that the peer belongs to // Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"` Groups []GroupMinimum `json:"groups"`
// Hostname of the machine
Hostname string `json:"hostname"`
// Peer ID // Peer ID
Id string `json:"id"` Id string `json:"id"`
@@ -155,6 +152,12 @@ type Peer struct {
// Indicates whether SSH server is enabled on this peer // Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"` SshEnabled bool `json:"ssh_enabled"`
// Peer's desktop UI version
UiVersion *string `json:"ui_version,omitempty"`
// User ID of the user that enrolled this peer
UserId *string `json:"user_id,omitempty"`
// Peer's daemon or cli version // Peer's daemon or cli version
Version string `json:"version"` Version string `json:"version"`
} }

View File

@@ -144,5 +144,8 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Version: peer.Meta.WtVersion, Version: peer.Meta.WtVersion,
Groups: groupsInfo, Groups: groupsInfo,
SshEnabled: peer.SSHEnabled, SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion,
} }
} }

View File

@@ -121,9 +121,11 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
return fmt.Errorf("connection to signal is not ready and in %s state", connState) return fmt.Errorf("connection to signal is not ready and in %s state", connState)
} }
// connect to Signal stream identifying ourselves with a public Wireguard key // connect to Signal stream identifying ourselves with a public WireGuard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String()) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connect(ctx, c.key.PublicKey().String())
if err != nil { if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err return err
@@ -180,15 +182,13 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
return c.connectedCh return c.connectedCh
} }
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil c.stream = nil
// add key fingerprint to the request header to be identified on the server side // add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key}) md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md) metaCtx := metadata.NewOutgoingContext(ctx, md)
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
c.stream = stream c.stream = stream
if err != nil { if err != nil {
return nil, err return nil, err