Compare commits

...

14 Commits

Author SHA1 Message Date
Maycon Santos
93d20e370b Add incoming routing rules (#486)
add an income firewall rule for each routing pair
the pair for the income rule has inverted
source and destination
2022-09-30 14:39:15 +05:00
Maycon Santos
878ca6db22 Check if domain from claim is valid (#485)
If domain is invalid we call GetAccountByUserOrAccountId
2022-09-29 13:51:18 +05:00
braginini
2033650908 Remove IdP client secret validation 2022-09-26 18:58:14 +02:00
Misha Bragin
34c1c7d901 Add hostname, userID, ui version to the HTTP API peer response (#479) 2022-09-26 18:02:45 +02:00
Misha Bragin
051fd3a4d7 Fix Management and Signal gRPC client stream leak (#482) 2022-09-26 18:02:20 +02:00
Misha Bragin
af69a48745 Support user role update (#478) 2022-09-23 14:18:42 +02:00
Maycon Santos
68ff97ba84 Parse and received provider proper error message (#476) 2022-09-23 14:18:29 +02:00
Misha Bragin
c5705803a5 Output plain NetBird IPv4 in status command (#474) 2022-09-22 09:25:52 +02:00
braginini
7e1ae448e0 Add extra logging to Sync and Login requests 2022-09-22 09:25:31 +02:00
Misha Bragin
518a2561a2 Add auto-assign groups to the User API (#467) 2022-09-22 09:06:32 +02:00
Maycon Santos
c75ffd0f4b Update ICE library (#471) 2022-09-20 11:40:18 +02:00
Maycon Santos
e4ad6174ca Improve module load (#470)
* Add additional check for needed kernel modules

* Check if wireguard and tun modules are loaded

If modules are loaded return true, otherwise attempt to load them

* fix state check

* Add module function tests

* Add test execution in container

* run client package tests on docker

* add package comment to new file

* force entrypoint

* add --privileged flag

* clean only if tables where created

* run from within the directories
2022-09-15 01:26:11 +05:00
Misha Bragin
6de313070a Always return empty auto_groups if previously were nil (#468) 2022-09-13 17:19:03 +02:00
Misha Bragin
cd7d1a80c9 Assign groups to peers when registering with the setup key (#466) 2022-09-13 13:39:46 +02:00
30 changed files with 1453 additions and 275 deletions

View File

@@ -33,3 +33,55 @@ jobs:
- name: Test
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: 1.18.x
- name: Cache Go modules
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: Install modules
run: go mod tidy
- name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/...
- name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
- name: Generate Engine Test bin
run: go test -c -o engine-testing.bin ./client/internal/*.go
- name: Generate Peer Test bin
run: go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin
- name: Run Engine tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"net"
"net/netip"
"sort"
"strings"
@@ -18,6 +19,7 @@ import (
var (
detailFlag bool
ipv4Flag bool
ipsFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
@@ -73,7 +75,7 @@ var statusCmd = &cobra.Command{
pbFullStatus := resp.GetFullStatus()
fullStatus := fromProtoFullStatus(pbFullStatus)
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion()))
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
return nil
},
@@ -82,8 +84,9 @@ var statusCmd = &cobra.Command{
func init() {
ipsFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g. --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g. --filter-by-status connected")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
}
func parseFilters() error {
@@ -142,7 +145,19 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
return fullStatus
}
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string) string {
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
interfaceIP := fullStatus.LocalPeerState.IP
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
return ""
}
if ipv4Flag {
return fmt.Sprintf("%s\n", ip)
}
var (
managementStatusURL = ""
signalStatusURL = ""
@@ -164,8 +179,6 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
signalConnString = "Connected"
}
interfaceIP := fullStatus.LocalPeerState.IP
if fullStatus.LocalPeerState.KernelInterface {
interfaceTypeString = "Kernel"
} else if fullStatus.LocalPeerState.IP == "" {

View File

@@ -263,7 +263,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
}
}
return DeviceAuthorizationFlow{
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: ProviderConfig{
@@ -274,5 +274,29 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
},
}, nil
}
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isProviderConfigValid(config ProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
return nil
}

View File

@@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
localPeerState := nbStatus.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModExists(),
KernelInterface: iface.WireguardModuleIsLoaded(),
}
statusRecorder.UpdateLocalPeerState(localPeerState)

View File

@@ -9,14 +9,16 @@ import (
import "github.com/google/nftables"
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
inNatFormat = "netbird-nat-in-%s"
inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
)
func genKey(format string, input string) string {
@@ -53,3 +55,13 @@ func NewFirewall(parentCTX context.Context) firewallManager {
return manager
}
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -311,7 +311,37 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
@@ -320,43 +350,22 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
ipVersion = ipv6
}
forwardRuleKey := genKey(forwardingFormat, pair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
delete(i.rules[ipVersion], forwardRuleKey)
delete(i.rules[ipVersion], ruleKey)
}
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
err = iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
i.rules[ipVersion][forwardRuleKey] = forwardRule
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], natRuleKey)
}
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
}
i.rules[ipVersion][natRuleKey] = natRule
i.rules[ipVersion][ruleKey] = rule
return nil
}
@@ -366,7 +375,37 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
@@ -375,29 +414,23 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
ipVersion = ipv6
}
forwardRuleKey := genKey(forwardingFormat, pair.ID)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
}
delete(i.rules[ipVersion], forwardRuleKey)
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], natRuleKey)
delete(i.rules[ipVersion], ruleKey)
return nil
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == iptablesNatTable {
ruleType = "nat"
}
return ruleType
}

View File

@@ -159,6 +159,17 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
@@ -172,7 +183,23 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should exist in the map")
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
@@ -213,12 +240,24 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
@@ -235,12 +274,26 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}

View File

@@ -12,7 +12,6 @@ import (
)
import "github.com/google/nftables"
//
const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
@@ -84,8 +83,10 @@ func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock()
defer n.mux.Unlock()
log.Debug("flushing tables")
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
if n.tableIPv4 != nil && n.tableIPv6 != nil {
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
@@ -246,53 +247,77 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
fwdKey := genKey(forwardingFormat, pair.ID)
if prefix.Addr().Unmap().Is4() {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
} else {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
}
if pair.masquerade {
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natKey := genKey(natFormat, pair.ID)
ruleKey := genKey(format, pair.ID)
if prefix.Addr().Unmap().Is4() {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
} else {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
_, exists := n.rules[ruleKey]
if exists {
err := n.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
err := n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
if prefix.Addr().Unmap().Is4() {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
}
return nil
}
@@ -307,26 +332,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err
}
fwdKey := genKey(forwardingFormat, pair.ID)
natKey := genKey(natFormat, pair.ID)
fwdRule, found := n.rules[fwdKey]
if found {
err = n.conn.DelRule(fwdRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
delete(n.rules, fwdKey)
err = n.removeRoutingRule(forwardingFormat, pair)
if err != nil {
return err
}
natRule, found := n.rules[natKey]
if found {
err = n.conn.DelRule(natRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing nat rule for %s", pair.destination)
delete(n.rules, natKey)
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
if err != nil {
return err
}
err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -335,6 +360,29 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {

View File

@@ -189,6 +189,45 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...)
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
@@ -241,6 +280,28 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
@@ -259,8 +320,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}

12
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 //keep this version otherwise wiretrustee up command breaks
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.18.1
github.com/pion/ice/v2 v2.1.17
github.com/pion/ice/v2 v2.2.7
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.3.0
@@ -42,7 +42,7 @@ require (
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.7.1
golang.org/x/net v0.0.0-20220513224357-95641704303c
golang.org/x/net v0.0.0-20220630215102-69896b714898
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
)
@@ -80,13 +80,13 @@ require (
github.com/nxadm/tail v1.4.8 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.1.2 // indirect
github.com/pion/dtls/v2 v2.1.5 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.5 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.3.5 // indirect
github.com/pion/transport v0.13.0 // indirect
github.com/pion/turn/v2 v2.0.7 // indirect
github.com/pion/transport v0.13.1 // indirect
github.com/pion/turn/v2 v2.0.8 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.12.2 // indirect
@@ -117,6 +117,4 @@ require (
k8s.io/apimachinery v0.23.5 // indirect
)
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84

25
go.sum
View File

@@ -505,8 +505,10 @@ github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTK
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pion/dtls/v2 v2.1.2 h1:22Q1Jk9L++Yo7BIf9130MonNPfPVb+YgdYLeyQotuAA=
github.com/pion/dtls/v2 v2.1.2/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
github.com/pion/ice/v2 v2.2.7 h1:kG9tux3WdYUSqqqnf+O5zKlpy41PdlvLUBlYJeV2emQ=
github.com/pion/ice/v2 v2.2.7/go.mod h1:Ckj7cWZ717rtU01YoDQA9ntGWCk95D42uVZ8sI0EL+8=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
@@ -516,10 +518,11 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.13.0 h1:KWTA5ZrQogizzYwPEciGtHPLwpAjE91FgXnyu+Hv2uY=
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
github.com/pion/turn/v2 v2.0.7 h1:SZhc00WDovK6czaN1RSiHqbwANtIO6wfZQsU0m0KNE8=
github.com/pion/turn/v2 v2.0.7/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o=
github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
@@ -624,8 +627,6 @@ github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJ
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb h1:CU1/+CEeCPvYXgfAyqTJXSQSf6hW3wsWM6Dfz6HkHEQ=
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb/go.mod h1:XT1Nrb4OxbVFPffbQMbq4PaeEkpRLVzdphh3fjrw7DY=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -662,7 +663,7 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -772,8 +773,10 @@ golang.org/x/net v0.0.0-20211208012354-db4efeb81f4b/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -906,6 +909,8 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View File

@@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error {
return nil
}
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
return false
}

View File

@@ -1,48 +1,29 @@
package iface
import (
"errors"
"math"
"os"
"syscall"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"os"
)
type NativeLink struct {
Link *netlink.Link
}
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid
// MTU here as the validation of the MTU will be performed after
// the validation of the link kind and hence allows us to check
// for the existance of the wireguard module without actually
// creating a link.
//
// As a side-effect, this will also let the kernel lazy-load
// the wireguard module.
link.attrs.MTU = math.MaxInt
err := netlink.LinkAdd(link)
return errors.Is(err, syscall.EINVAL)
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
if WireguardModExists() {
if WireguardModuleIsLoaded() {
log.Info("using kernel WireGuard")
return w.createWithKernel()
} else {
if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module")
}
log.Info("using userspace WireGuard")
return w.createWithUserspace()
}

View File

@@ -58,7 +58,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
return w.assignAddr(luid)
}
// WireguardModExists check if we can load wireguard mod (linux only)
func WireguardModExists() bool {
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
return false
}

349
iface/module_linux.go Normal file
View File

@@ -0,0 +1,349 @@
// Package iface provides wireguard network interface creation and management
package iface
import (
"bufio"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io/fs"
"io/ioutil"
"math"
"os"
"path/filepath"
"strings"
"syscall"
)
// Holds logic to check existence of kernel modules used by wireguard interfaces
// Copied from https://github.com/paultag/go-modprobe and
// https://github.com/pmorjan/kmod
type status int
const (
defaultModuleDir = "/lib/modules"
unknown status = iota
unloaded
unloading
loading
live
inuse
)
type module struct {
name string
path string
}
var (
// ErrModuleNotFound is the error resulting if a module can't be found.
ErrModuleNotFound = errors.New("module not found")
moduleLibDir = defaultModuleDir
// get the root directory for the kernel modules. If this line panics,
// it's because getModuleRoot has failed to get the uname of the running
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
moduleRoot = getModuleRoot()
)
// Get the module root (/lib/modules/$(uname -r)/)
func getModuleRoot() string {
uname := unix.Utsname{}
if err := unix.Uname(&uname); err != nil {
panic(err)
}
i := 0
for ; uname.Release[i] != 0; i++ {
}
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true
}
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
tunLoaded, err := tryToLoadModule("tun")
if err != nil {
log.Errorf("unable to find or load tun module, got error: %v", err)
}
return tunLoaded
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
}
func canCreateFakeWireguardInterface() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid
// MTU here as the validation of the MTU will be performed after
// the validation of the link kind and hence allows us to check
// for the existance of the wireguard module without actually
// creating a link.
//
// As a side-effect, this will also let the kernel lazy-load
// the wireguard module.
link.attrs.MTU = math.MaxInt
err := netlink.LinkAdd(link)
return errors.Is(err, syscall.EINVAL)
}
func tryToLoadModule(moduleName string) (bool, error) {
if isModuleEnabled(moduleName) {
return true, nil
}
modulePath, err := getModulePath(moduleName)
if err != nil {
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
}
if modulePath == "" {
return false, nil
}
log.Infof("trying to load %s module", moduleName)
err = loadModuleWithDependencies(moduleName, modulePath)
if err != nil {
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
}
return true, nil
}
func isModuleEnabled(name string) bool {
builtin, builtinErr := isBuiltinModule(name)
state, statusErr := moduleStatus(name)
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
}
func getModulePath(name string) (string, error) {
var foundPath string
skipRemainingDirs := false
err := filepath.WalkDir(
moduleRoot,
func(path string, info fs.DirEntry, err error) error {
if skipRemainingDirs {
return fs.SkipDir
}
if err != nil {
// skip broken files
return nil
}
if !info.Type().IsRegular() {
return nil
}
nameFromPath := pathToName(path)
if nameFromPath == name {
foundPath = path
skipRemainingDirs = true
}
return nil
})
if err != nil {
return "", err
}
return foundPath, nil
}
func pathToName(s string) string {
s = filepath.Base(s)
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
s = strings.TrimSuffix(s, ext)
}
return cleanName(s)
}
func cleanName(s string) string {
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
}
func isBuiltinModule(name string) (bool, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
if err != nil {
return false, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.builtin file, %v", err)
}
}()
var found bool
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if pathToName(line) == name {
found = true
break
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return found, nil
}
// /proc/modules
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) {
state := unknown
f, err := os.Open("/proc/modules")
if err != nil {
return state, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing /proc/modules file, %v", err)
}
}()
state = unloaded
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if fields[0] == name {
if fields[2] != "0" {
state = inuse
break
}
switch fields[4] {
case "Live":
state = live
case "Loading":
state = loading
case "Unloading":
state = unloading
}
break
}
}
if err := scanner.Err(); err != nil {
return state, err
}
return state, nil
}
func loadModuleWithDependencies(name, path string) error {
deps, err := getModuleDependencies(name)
if err != nil {
return fmt.Errorf("couldn't load list of module %s dependecies", name)
}
for _, dep := range deps {
err = loadModule(dep.name, dep.path)
if err != nil {
return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name)
}
}
return loadModule(name, path)
}
func loadModule(name, path string) error {
state, err := moduleStatus(name)
if err != nil {
return err
}
if state >= loading {
return nil
}
f, err := os.Open(path)
if err != nil {
return err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing %s file, %v", path, err)
}
}()
// first try finit_module(2), then init_module(2)
err = unix.FinitModule(int(f.Fd()), "", 0)
if errors.Is(err, unix.ENOSYS) {
buf, err := ioutil.ReadAll(f)
if err != nil {
return err
}
return unix.InitModule(buf, "")
}
return err
}
// getModuleDependencies returns a module dependencies
func getModuleDependencies(name string) ([]module, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
if err != nil {
return nil, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.dep file, %v", err)
}
}()
var deps []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
deps = fields
break
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if len(deps) == 0 {
return nil, ErrModuleNotFound
}
deps[0] = strings.TrimSuffix(deps[0], ":")
var modules []module
for _, v := range deps {
if pathToName(v) != name {
modules = append(modules, module{
name: pathToName(v),
path: filepath.Join(moduleRoot, v),
})
}
}
return modules, nil
}

221
iface/module_linux_test.go Normal file
View File

@@ -0,0 +1,221 @@
package iface
import (
"bufio"
"bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
func TestGetModuleDependencies(t *testing.T) {
testCases := []struct {
name string
module string
expected []module
}{
{
name: "Get Single Dependency",
module: "bar",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
},
},
{
name: "Get Multiple Dependencies",
module: "baz",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
},
},
{
name: "Get No Dependencies",
module: "foo",
expected: []module{},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
modules, err := getModuleDependencies(testCase.module)
require.NoError(t, err)
expected := testCase.expected
for i := range expected {
expected[i].path = moduleRoot + "/" + expected[i].path
}
require.ElementsMatchf(t, modules, expected, "returned modules should match")
})
}
}
func TestIsBuiltinModule(t *testing.T) {
testCases := []struct {
name string
module string
expected bool
}{
{
name: "Built In Should Return True",
module: "foo_bi",
expected: true,
},
{
name: "Not Built In Should Return False",
module: "not_built_in",
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
isBuiltIn, err := isBuiltinModule(testCase.module)
require.NoError(t, err)
require.Equal(t, testCase.expected, isBuiltIn)
})
}
}
func TestModuleStatus(t *testing.T) {
random, err := getRandomLoadedModule(t)
if err != nil {
t.Fatal("should be able to get random module")
}
testCases := []struct {
name string
module string
shouldBeLoaded bool
}{
{
name: "Should Return Module Loading Or Greater Status",
module: random,
shouldBeLoaded: true,
},
{
name: "Should Return Module Unloaded Or Lower Status",
module: "not_loaded_module",
shouldBeLoaded: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
state, err := moduleStatus(testCase.module)
require.NoError(t, err)
if testCase.shouldBeLoaded {
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
} else {
require.Less(t, state, loading, "module should return state unloading or lower")
}
})
}
}
func resetGlobals() {
moduleLibDir = defaultModuleDir
moduleRoot = getModuleRoot()
}
func createFiles(t *testing.T) (string, []module) {
writeFile := func(path, text string) {
if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err)
}
}
var u unix.Utsname
if err := unix.Uname(&u); err != nil {
t.Fatal(err)
}
moduleLibDir = t.TempDir()
moduleRoot = getModuleRoot()
if err := os.Mkdir(moduleRoot, 0755); err != nil {
t.Fatal(err)
}
text := "kernel/a/foo.ko:\n"
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
text = "kernel/a/foo_bi.ko\n"
text += "kernel/a/bar-bi.ko.gz\n"
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
modules := []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
{name: "baz", path: "kernel/a/baz.ko"},
}
return moduleLibDir, modules
}
func getRandomLoadedModule(t *testing.T) (string, error) {
f, err := os.Open("/proc/modules")
if err != nil {
return "", err
}
defer func() {
err := f.Close()
if err != nil {
t.Logf("failed closing /proc/modules file, %v", err)
}
}()
lines, err := lineCounter(f)
if err != nil {
return "", err
}
counter := 1
midLine := lines / 2
modName := ""
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if counter == midLine {
if fields[4] == "Unloading" {
continue
}
modName = fields[0]
break
}
counter++
}
if scanner.Err() != nil {
return "", scanner.Err()
}
return modName, nil
}
func lineCounter(r io.Reader) (int, error) {
buf := make([]byte, 32*1024)
count := 0
lineSep := []byte{'\n'}
for {
c, err := r.Read(buf)
count += bytes.Count(buf[:c], lineSep)
switch {
case err == io.EOF:
return count, nil
case err != nil:
return count, err
}
}
}

View File

@@ -109,7 +109,9 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return err
}
stream, err := c.connectToStream(*serverPubKey)
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -145,7 +147,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil
}
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
@@ -156,9 +158,12 @@ func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.Management
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
return c.realClient.Sync(c.ctx, syncReq)
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {

View File

@@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/status"
"math/rand"
"reflect"
"regexp"
"strings"
"sync"
"time"
@@ -39,6 +40,7 @@ type AccountManager interface {
autoGroups []string,
) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
@@ -107,10 +109,11 @@ type Account struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
}
func (a *Account) Copy() *Account {
@@ -300,19 +303,25 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
return nil
}
func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
return &UserInfo{
ID: local.Id,
Email: queried.Email,
Name: queried.Name,
Role: string(local.Role),
}
}
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
}
func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) {
userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID)
if err != nil {
return nil, err
}
for _, datum := range userData {
if datum.ID == user.Id {
return datum, nil
}
}
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id)
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
data, err := am.cacheManager.Get(am.ctx, accountID)
if err != nil {
@@ -352,46 +361,6 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco
return userData, err
}
// GetUsersFromAccount performs a batched request for users from IDP by account id
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
if err != nil {
return nil, err
}
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
queriedUsers, err = am.lookupCache(account.Users, accountID)
if err != nil {
return nil, err
}
}
userInfo := make([]*UserInfo, 0)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, user := range account.Users {
userInfo = append(userInfo, &UserInfo{
ID: user.Id,
Email: "",
Name: "",
Role: string(user.Role),
})
}
return userInfo, nil
}
for _, queriedUser := range queriedUsers {
if localUser, contains := account.Users[queriedUser.ID]; contains {
userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser))
log.Debugf("Merged userinfo to send back; %v", userInfo)
}
}
return userInfo, nil
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(
account *Account,
@@ -512,7 +481,7 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
) (*Account, error) {
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory {
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId)
@@ -552,6 +521,11 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
}
}
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
}
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock()

View File

@@ -140,6 +140,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG string
expectedUserRole UserRole
expectedDomainCategory string
expectedDomain string
expectedPrimaryDomainStatus bool
expectedCreatedBy string
expectedUsers []string
@@ -168,6 +169,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"},
@@ -188,6 +190,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: unknownDomain,
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user",
@@ -205,6 +208,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user",
@@ -227,6 +231,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
@@ -244,6 +249,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
@@ -262,12 +268,32 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
testCase7 := test{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -294,6 +320,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
require.EqualValues(t, testCase.expectedDomain, account.Domain, "expected account domain should match")
})
}
}

View File

@@ -17,6 +17,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gRPCPeer "google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
@@ -79,7 +80,10 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
log.Debugf("Sync request from peer %s", req.WgPubKey)
p, ok := gRPCPeer.FromContext(srv.Context())
if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -255,7 +259,10 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Debugf("Login request from peer %s", req.WgPubKey)
p, ok := gRPCPeer.FromContext(ctx)
if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {

View File

@@ -31,13 +31,33 @@ components:
description: User's name from idp provider
type: string
role:
description: User's Netbird account role
description: User's NetBird account role
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- id
- email
- name
- role
- auto_groups
UserRequest:
type: object
properties:
role:
description: User's NetBird account role
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- role
- auto_groups
PeerMinimum:
type: object
properties:
@@ -76,20 +96,18 @@ components:
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
activated_by:
description: Provides information of who activated the Peer. User or Setup Key
type: object
properties:
type:
type: string
value:
type: string
required:
- type
- value
ssh_enabled:
description: Indicates whether SSH server is enabled on this peer
type: boolean
user_id:
description: User ID of the user that enrolled this peer
type: string
hostname:
description: Hostname of the machine
type: string
ui_version:
description: Peer's desktop UI version
type: string
required:
- ip
- connected
@@ -97,8 +115,8 @@ components:
- os
- version
- groups
- activated_by
- ssh_enabled
- hostname
SetupKey:
type: object
properties:
@@ -409,6 +427,40 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{id}:
put:
summary: Update information about a User
tags: [ Users]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The User ID
requestBody:
description: User update
content:
'application/json':
schema:
$ref: '#/components/schemas/UserRequest'
responses:
'200':
description: A User object
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: Returns a list of all peers

View File

@@ -125,18 +125,15 @@ type PatchMinimumOp string
// Peer defines model for Peer.
type Peer struct {
// Provides information of who activated the Peer. User or Setup Key
ActivatedBy struct {
Type string `json:"type"`
Value string `json:"value"`
} `json:"activated_by"`
// Peer to Management connection status
Connected bool `json:"connected"`
// Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname of the machine
Hostname string `json:"hostname"`
// Peer ID
Id string `json:"id"`
@@ -155,6 +152,12 @@ type Peer struct {
// Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// Peer's desktop UI version
UiVersion *string `json:"ui_version,omitempty"`
// User ID of the user that enrolled this peer
UserId *string `json:"user_id,omitempty"`
// Peer's daemon or cli version
Version string `json:"version"`
}
@@ -356,6 +359,9 @@ type SetupKeyRequest struct {
// User defines model for User.
type User struct {
// Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// User's email address
Email string `json:"email"`
@@ -365,7 +371,16 @@ type User struct {
// User's name from idp provider
Name string `json:"name"`
// User's Netbird account role
// User's NetBird account role
Role string `json:"role"`
}
// UserRequest defines model for UserRequest.
type UserRequest struct {
// Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// User's NetBird account role
Role string `json:"role"`
}
@@ -442,6 +457,9 @@ type PostApiSetupKeysJSONBody = SetupKeyRequest
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId.
type PutApiSetupKeysIdJSONBody = SetupKeyRequest
// PutApiUsersIdJSONBody defines parameters for PutApiUsersId.
type PutApiUsersIdJSONBody = UserRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
@@ -477,3 +495,6 @@ type PostApiSetupKeysJSONRequestBody = PostApiSetupKeysJSONBody
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
type PutApiUsersIdJSONRequestBody = PutApiUsersIdJSONBody

View File

@@ -39,6 +39,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")

View File

@@ -11,7 +11,7 @@ import (
"net/http"
)
//Peers is a handler that returns peers of the account
// Peers is a handler that returns peers of the account
type Peers struct {
accountManager server.AccountManager
authAudience string
@@ -144,5 +144,8 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion,
}
}

View File

@@ -1,7 +1,12 @@
package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
log "github.com/sirupsen/logrus"
@@ -24,6 +29,59 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
}
}
// UpdateUser is a PUT requests to update User data
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
userID := vars["id"]
if len(userID) == 0 {
http.Error(w, "invalid user ID", http.StatusBadRequest)
return
}
req := &api.PutApiUsersIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
http.Error(w, "invalid user role", http.StatusBadRequest)
return
}
newUser, err := h.accountManager.SaveUser(account.Id, &server.User{
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
default:
http.Error(w, "failed to update user", http.StatusInternalServerError)
}
}
return
}
writeJSONObject(w, toUserResponse(newUser))
}
// GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
@@ -52,10 +110,17 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
}
func toUserResponse(user *server.UserInfo) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
return &api.User{
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
AutoGroups: autoGroups,
}
}

View File

@@ -52,6 +52,7 @@ type MockAccountManager struct {
ListRoutesFunc func(accountID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@@ -421,3 +422,11 @@ func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKe
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
}
// SaveUser mocks SaveUser of the AccountManager interface
func (am *MockAccountManager) SaveUser(accountID string, user *server.User) (*server.UserInfo, error) {
if am.SaveUserFunc != nil {
return am.SaveUserFunc(accountID, user)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
}

View File

@@ -294,6 +294,8 @@ func (am *DefaultAccountManager) AddPeer(
var account *Account
var err error
var sk *SetupKey
// auto-assign groups that are coming with a SetupKey or a User
var groupsToAdd []string
if len(upperKey) != 0 {
account, err = am.Store.GetAccountBySetupKey(upperKey)
if err != nil {
@@ -321,11 +323,20 @@ func (am *DefaultAccountManager) AddPeer(
)
}
groupsToAdd = sk.AutoGroups
} else if len(userID) != 0 {
account, err = am.Store.GetUserAccount(userID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
}
user, ok := account.Users[userID]
if !ok {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
}
groupsToAdd = user.AutoGroups
} else {
// Empty setup key and jwt fail
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
@@ -361,6 +372,14 @@ func (am *DefaultAccountManager) AddPeer(
}
group.Peers = append(group.Peers, newPeer.Key)
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.Key)
}
}
}
account.Peers[newPeer.Key] = newPeer
if len(upperKey) != 0 {
account.SetupKeys[sk.Key] = sk.IncrementUsage()

View File

@@ -80,6 +80,8 @@ type SetupKey struct {
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
autoGroups := make([]string, 0)
autoGroups = append(autoGroups, key.AutoGroups...)
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
@@ -94,7 +96,7 @@ func (key *SetupKey) Copy() *SetupKey {
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: key.AutoGroups,
AutoGroups: autoGroups,
}
}

View File

@@ -2,19 +2,32 @@ package server
import (
"fmt"
"strings"
"github.com/netbirdio/netbird/management/server/idp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strings"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
)
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
func StrRoleToUserRole(strRole string) UserRole {
switch strings.ToLower(strRole) {
case "admin":
return UserRoleAdmin
case "user":
return UserRoleUser
default:
return UserRoleUnknown
}
}
// UserRole is the role of the User
type UserRole string
@@ -22,20 +35,56 @@ type UserRole string
type User struct {
Id string
Role UserRole
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string
}
// toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: "",
Role: string(u.Role),
AutoGroups: u.AutoGroups,
}, nil
}
if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
}, nil
}
// Copy the user
func (u *User) Copy() *User {
autoGroups := []string{}
autoGroups = append(autoGroups, u.AutoGroups...)
return &User{
Id: u.Id,
Role: u.Role,
Id: u.Id,
Role: u.Role,
AutoGroups: autoGroups,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole) *User {
return &User{
Id: id,
Role: role,
Id: id,
Role: role,
AutoGroups: []string{},
}
}
@@ -49,6 +98,55 @@ func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin)
}
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil,
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(codes.NotFound, "update not found")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.AutoGroups = update.AutoGroups
newUser.Role = update.Role
account.Users[newUser.Id] = newUser
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
if !isNil(am.idpManager) {
userData, err := am.lookupUserInCache(newUser, accountID)
if err != nil {
return nil, err
}
return newUser.toUserInfo(userData)
}
return newUser.toUserInfo(nil)
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
am.mux.Lock()
@@ -108,3 +206,46 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
if err != nil {
return nil, err
}
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
queriedUsers, err = am.lookupCache(account.Users, accountID)
if err != nil {
return nil, err
}
}
userInfos := make([]*UserInfo, 0)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, user := range account.Users {
info, err := user.toUserInfo(nil)
if err != nil {
return nil, err
}
userInfos = append(userInfos, info)
}
return userInfos, nil
}
for _, queriedUser := range queriedUsers {
if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser)
if err != nil {
return nil, err
}
userInfos = append(userInfos, info)
}
}
return userInfos, nil
}

View File

@@ -87,7 +87,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
}, nil
}
//defaultBackoff is a basic backoff mechanism for general issues
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
@@ -121,9 +121,11 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
}
// connect to Signal stream identifying ourselves with a public Wireguard key
// connect to Signal stream identifying ourselves with a public WireGuard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String())
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connect(ctx, c.key.PublicKey().String())
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err
@@ -180,15 +182,13 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
return c.connectedCh
}
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md)
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
metaCtx := metadata.NewOutgoingContext(ctx, md)
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
c.stream = stream
if err != nil {
return nil, err