mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 12:16:39 +00:00
Compare commits
19 Commits
v0.8.11
...
feature/ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d197cd5f9 | ||
|
|
6bee984b46 | ||
|
|
2ee7d69f80 | ||
|
|
af69a48745 | ||
|
|
68ff97ba84 | ||
|
|
c5705803a5 | ||
|
|
7e1ae448e0 | ||
|
|
518a2561a2 | ||
|
|
c75ffd0f4b | ||
|
|
e4ad6174ca | ||
|
|
6de313070a | ||
|
|
cd7d1a80c9 | ||
|
|
be7d829858 | ||
|
|
ed1872560f | ||
|
|
de898899a4 | ||
|
|
b63ec71aed | ||
|
|
1012172f04 | ||
|
|
788bb00ef1 | ||
|
|
4e5ee70b3d |
52
.github/workflows/golang-test-linux.yml
vendored
52
.github/workflows/golang-test-linux.yml
vendored
@@ -33,3 +33,55 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.18.x
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: go test -c -o iface-testing.bin ./iface/...
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin
|
||||
|
||||
- name: Run Engine tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin
|
||||
28
README.md
28
README.md
@@ -16,7 +16,7 @@
|
||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<img src="https://img.shields.io/badge/slack-@wiretrustee-red.svg?logo=slack"/>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
@@ -43,20 +43,20 @@ It requires zero configuration effort leaving behind the hassle of opening ports
|
||||
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
|
||||
|
||||
**Key features:**
|
||||
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
||||
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
||||
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
||||
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
||||
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
||||
- \[x] Multiuser support - sharing network between multiple users.
|
||||
- \[x] SSO and MFA support.
|
||||
- \[x] Multicloud and hybrid-cloud support.
|
||||
- \[x] Kernel WireGuard usage when possible.
|
||||
- \[x] Access Controls - groups & rules.
|
||||
- \[x] Remote SSH access without managing SSH keys.
|
||||
|
||||
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
||||
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
||||
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
||||
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
||||
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
||||
- \[x] Multiuser support - sharing network between multiple users.
|
||||
- \[x] SSO and MFA support.
|
||||
- \[x] Multicloud and hybrid-cloud support.
|
||||
- \[x] Kernel WireGuard usage when possible.
|
||||
- \[x] Access Controls - groups & rules.
|
||||
- \[x] Remote SSH access without managing SSH keys.
|
||||
- \[x] Network Routes.
|
||||
|
||||
**Coming soon:**
|
||||
- \[ ] Router nodes
|
||||
- \[ ] Private DNS.
|
||||
- \[ ] Mobile clients.
|
||||
- \[ ] Network Activity Monitoring.
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
ipsFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
@@ -73,7 +75,7 @@ var statusCmd = &cobra.Command{
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
fullStatus := fromProtoFullStatus(pbFullStatus)
|
||||
|
||||
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion()))
|
||||
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
|
||||
|
||||
return nil
|
||||
},
|
||||
@@ -82,8 +84,9 @@ var statusCmd = &cobra.Command{
|
||||
func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g. --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g. --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
@@ -142,7 +145,19 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string) string {
|
||||
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
|
||||
|
||||
interfaceIP := fullStatus.LocalPeerState.IP
|
||||
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if ipv4Flag {
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
}
|
||||
|
||||
var (
|
||||
managementStatusURL = ""
|
||||
signalStatusURL = ""
|
||||
@@ -164,8 +179,6 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
|
||||
signalConnString = "Connected"
|
||||
}
|
||||
|
||||
interfaceIP := fullStatus.LocalPeerState.IP
|
||||
|
||||
if fullStatus.LocalPeerState.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if fullStatus.LocalPeerState.IP == "" {
|
||||
|
||||
@@ -37,6 +37,7 @@ type Config struct {
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
IFaceBlackList []string
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
@@ -49,7 +50,13 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := &Config{SSHKey: string(pem), PrivateKey: wgKey, WgIface: iface.WgInterfaceDefault, IFaceBlackList: []string{}}
|
||||
config := &Config{
|
||||
SSHKey: string(pem),
|
||||
PrivateKey: wgKey,
|
||||
WgIface: iface.WgInterfaceDefault,
|
||||
WgPort: iface.DefaultWgPort,
|
||||
IFaceBlackList: []string{},
|
||||
}
|
||||
if managementURL != "" {
|
||||
URL, err := ParseURL("Management URL", managementURL)
|
||||
if err != nil {
|
||||
@@ -72,8 +79,8 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
|
||||
config.AdminURL = newURL
|
||||
}
|
||||
|
||||
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||
"Tailscale", "tailscale"}
|
||||
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "vet"}
|
||||
|
||||
err = util.WriteJson(configPath, config)
|
||||
if err != nil {
|
||||
@@ -150,6 +157,11 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if refresh {
|
||||
// since we have new management URL, we need to update config file
|
||||
if err := util.WriteJson(configPath, config); err != nil {
|
||||
@@ -251,7 +263,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
|
||||
}
|
||||
}
|
||||
|
||||
return DeviceAuthorizationFlow{
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
@@ -262,5 +274,32 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isProviderConfigValid(config ProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.ClientSecret == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client Secret")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
localPeerState := nbStatus.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireguardModExists(),
|
||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||
}
|
||||
|
||||
statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
@@ -188,7 +188,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
WgPrivateKey: key,
|
||||
WgPort: iface.DefaultWgPort,
|
||||
WgPort: config.WgPort,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -99,6 +101,8 @@ type Engine struct {
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *nbstatus.Status
|
||||
|
||||
routeManager routemanager.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -182,6 +186,10 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
@@ -232,6 +240,8 @@ func (e *Engine) Start() error {
|
||||
return err
|
||||
}
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
|
||||
@@ -388,7 +398,8 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
|
||||
return nil
|
||||
}
|
||||
|
||||
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||
var t sProto.Body_Type
|
||||
if isAnswer {
|
||||
t = sProto.Body_ANSWER
|
||||
@@ -396,9 +407,9 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K
|
||||
t = sProto.Body_OFFER
|
||||
}
|
||||
|
||||
msg, err := signal.MarshalCredential(myKey, remoteKey, &signal.Credential{
|
||||
UFrag: uFrag,
|
||||
Pwd: pwd,
|
||||
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
}, t)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -618,11 +629,37 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update routes, err: %v", err)
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
}
|
||||
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||
convertedRoute := &route.Route{
|
||||
ID: protoRoute.ID,
|
||||
Network: prefix,
|
||||
NetID: protoRoute.NetID,
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
Peer: protoRoute.Peer,
|
||||
Metric: int(protoRoute.Metric),
|
||||
Masquerade: protoRoute.Masquerade,
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
@@ -726,6 +763,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
UDPMux: e.udpMux,
|
||||
UDPMuxSrflx: e.udpMuxSrflx,
|
||||
ProxyConfig: proxyConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder)
|
||||
@@ -738,16 +776,16 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signalOffer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, false)
|
||||
signalOffer := func(offerAnswer peer.OfferAnswer) error {
|
||||
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
|
||||
}
|
||||
|
||||
signalCandidate := func(candidate ice.Candidate) error {
|
||||
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
|
||||
}
|
||||
|
||||
signalAnswer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, true)
|
||||
signalAnswer := func(offerAnswer peer.OfferAnswer) error {
|
||||
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
|
||||
}
|
||||
|
||||
peerConn.SetSignalCandidate(signalCandidate)
|
||||
@@ -776,18 +814,26 @@ func (e *Engine) receiveSignalEvents() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteOffer(peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
conn.OnRemoteOffer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
},
|
||||
WgListenPort: int(msg.GetBody().GetWgListenPort()),
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
})
|
||||
case sProto.Body_ANSWER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteAnswer(peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
conn.OnRemoteAnswer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
},
|
||||
WgListenPort: int(msg.GetBody().GetWgListenPort()),
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
})
|
||||
case sProto.Body_CANDIDATE:
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
|
||||
@@ -3,11 +3,14 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -196,6 +199,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -426,6 +430,142 @@ func TestEngine_Sync(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputErr error
|
||||
networkMap *mgmtProto.NetworkMap
|
||||
expectedLen int
|
||||
expectedRoutes []*route.Route
|
||||
expectedSerial uint64
|
||||
}{
|
||||
{
|
||||
name: "Routes Update Should Be Passed To Manager",
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 1,
|
||||
PeerConfig: nil,
|
||||
RemotePeersIsEmpty: false,
|
||||
Routes: []*mgmtProto.Route{
|
||||
{
|
||||
ID: "a",
|
||||
Network: "192.168.0.0/24",
|
||||
NetID: "n1",
|
||||
Peer: "p1",
|
||||
NetworkType: 1,
|
||||
Masquerade: false,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
Network: "192.168.1.0/24",
|
||||
NetID: "n2",
|
||||
Peer: "p1",
|
||||
NetworkType: 1,
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 2,
|
||||
expectedRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
NetID: "n1",
|
||||
Peer: "p1",
|
||||
NetworkType: 1,
|
||||
Masquerade: false,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
NetID: "n2",
|
||||
Peer: "p1",
|
||||
NetworkType: 1,
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
expectedSerial: 1,
|
||||
},
|
||||
{
|
||||
name: "Empty Routes Update Should Be Passed",
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 1,
|
||||
PeerConfig: nil,
|
||||
RemotePeersIsEmpty: false,
|
||||
Routes: nil,
|
||||
},
|
||||
expectedLen: 0,
|
||||
expectedRoutes: []*route.Route{},
|
||||
expectedSerial: 1,
|
||||
},
|
||||
{
|
||||
name: "Error Shouldn't Break Engine",
|
||||
inputErr: fmt.Errorf("mocking error"),
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 1,
|
||||
PeerConfig: nil,
|
||||
RemotePeersIsEmpty: false,
|
||||
Routes: nil,
|
||||
},
|
||||
expectedLen: 0,
|
||||
expectedRoutes: []*route.Route{},
|
||||
expectedSerial: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
// test setup
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
inputRoutes []*route.Route
|
||||
}{}
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
input.inputSerial = updateSerial
|
||||
input.inputRoutes = newRoutes
|
||||
return testCase.inputErr
|
||||
},
|
||||
}
|
||||
|
||||
engine.routeManager = mockRouteManager
|
||||
|
||||
defer func() {
|
||||
exitErr := engine.Stop()
|
||||
if exitErr != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
|
||||
@@ -36,7 +36,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"net"
|
||||
@@ -36,6 +37,20 @@ type ConnConfig struct {
|
||||
|
||||
UDPMux ice.UDPMux
|
||||
UDPMuxSrflx ice.UniversalUDPMux
|
||||
|
||||
LocalWgPort int
|
||||
}
|
||||
|
||||
// OfferAnswer represents a session establishment offer or answer
|
||||
type OfferAnswer struct {
|
||||
IceCredentials IceCredentials
|
||||
// WgListenPort is a remote WireGuard listen port.
|
||||
// This field is used when establishing a direct WireGuard connection without any proxy.
|
||||
// We can set the remote peer's endpoint with this port.
|
||||
WgListenPort int
|
||||
|
||||
// Version of NetBird Agent
|
||||
Version string
|
||||
}
|
||||
|
||||
// IceCredentials ICE protocol credentials struct
|
||||
@@ -51,13 +66,13 @@ type Conn struct {
|
||||
// signalCandidate is a handler function to signal remote peer about local connection candidate
|
||||
signalCandidate func(candidate ice.Candidate) error
|
||||
// signalOffer is a handler function to signal remote peer our connection offer (credentials)
|
||||
signalOffer func(uFrag string, pwd string) error
|
||||
signalAnswer func(uFrag string, pwd string) error
|
||||
signalOffer func(OfferAnswer) error
|
||||
signalAnswer func(OfferAnswer) error
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan IceCredentials
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
remoteAnswerCh chan IceCredentials
|
||||
remoteAnswerCh chan OfferAnswer
|
||||
closeCh chan struct{}
|
||||
ctx context.Context
|
||||
notifyDisconnected context.CancelFunc
|
||||
@@ -88,8 +103,8 @@ func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error)
|
||||
mu: sync.Mutex{},
|
||||
status: StatusDisconnected,
|
||||
closeCh: make(chan struct{}),
|
||||
remoteOffersCh: make(chan IceCredentials),
|
||||
remoteAnswerCh: make(chan IceCredentials),
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
}, nil
|
||||
}
|
||||
@@ -200,15 +215,15 @@ func (conn *Conn) Open() error {
|
||||
// Only continue once we got a connection confirmation from the remote peer.
|
||||
// The connection timeout could have happened before a confirmation received from the remote.
|
||||
// The connection could have also been closed externally (e.g. when we received an update from the management that peer shouldn't be connected)
|
||||
var remoteCredentials IceCredentials
|
||||
var remoteOfferAnswer OfferAnswer
|
||||
select {
|
||||
case remoteCredentials = <-conn.remoteOffersCh:
|
||||
case remoteOfferAnswer = <-conn.remoteOffersCh:
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
err = conn.sendAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case remoteCredentials = <-conn.remoteAnswerCh:
|
||||
case remoteOfferAnswer = <-conn.remoteAnswerCh:
|
||||
case <-time.After(conn.config.Timeout):
|
||||
return NewConnectionTimeoutError(conn.config.Key, conn.config.Timeout)
|
||||
case <-conn.closeCh:
|
||||
@@ -216,7 +231,8 @@ func (conn *Conn) Open() error {
|
||||
return NewConnectionClosedError(conn.config.Key)
|
||||
}
|
||||
|
||||
log.Debugf("received connection confirmation from peer %s", conn.config.Key)
|
||||
log.Debugf("received connection confirmation from peer %s running version %s and with remote WireGuard listen port %d",
|
||||
conn.config.Key, remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
|
||||
|
||||
// at this point we received offer/answer and we are ready to gather candidates
|
||||
conn.mu.Lock()
|
||||
@@ -245,16 +261,21 @@ func (conn *Conn) Open() error {
|
||||
isControlling := conn.config.LocalKey > conn.config.Key
|
||||
var remoteConn *ice.Conn
|
||||
if isControlling {
|
||||
remoteConn, err = conn.agent.Dial(conn.ctx, remoteCredentials.UFrag, remoteCredentials.Pwd)
|
||||
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
remoteConn, err = conn.agent.Accept(conn.ctx, remoteCredentials.UFrag, remoteCredentials.Pwd)
|
||||
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamically set remote WireGuard port is other side specified a different one from the default one
|
||||
remoteWgPort := iface.DefaultWgPort
|
||||
if remoteOfferAnswer.WgListenPort != 0 {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
err = conn.startProxy(remoteConn)
|
||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -319,7 +340,7 @@ func IsPublicIP(ip net.IP) bool {
|
||||
}
|
||||
|
||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) startProxy(remoteConn net.Conn) error {
|
||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -336,7 +357,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn) error {
|
||||
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
|
||||
peerState.Direct = false
|
||||
} else {
|
||||
p = proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
peerState.Direct = true
|
||||
}
|
||||
conn.proxy = p
|
||||
@@ -409,12 +430,12 @@ func (conn *Conn) cleanup() error {
|
||||
}
|
||||
|
||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||
func (conn *Conn) SetSignalOffer(handler func(uFrag string, pwd string) error) {
|
||||
func (conn *Conn) SetSignalOffer(handler func(offer OfferAnswer) error) {
|
||||
conn.signalOffer = handler
|
||||
}
|
||||
|
||||
// SetSignalAnswer sets a handler function to be triggered by Conn when a new connection answer has to be signalled to the remote peer
|
||||
func (conn *Conn) SetSignalAnswer(handler func(uFrag string, pwd string) error) {
|
||||
func (conn *Conn) SetSignalAnswer(handler func(answer OfferAnswer) error) {
|
||||
conn.signalAnswer = handler
|
||||
}
|
||||
|
||||
@@ -459,8 +480,12 @@ func (conn *Conn) sendAnswer() error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("sending asnwer to %s", conn.config.Key)
|
||||
err = conn.signalAnswer(localUFrag, localPwd)
|
||||
log.Debugf("sending answer to %s", conn.config.Key)
|
||||
err = conn.signalAnswer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||
WgListenPort: conn.config.LocalWgPort,
|
||||
Version: system.NetbirdVersion(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -477,7 +502,11 @@ func (conn *Conn) sendOffer() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = conn.signalOffer(localUFrag, localPwd)
|
||||
err = conn.signalOffer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||
WgListenPort: conn.config.LocalWgPort,
|
||||
Version: system.NetbirdVersion(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -518,11 +547,11 @@ func (conn *Conn) Status() ConnStatus {
|
||||
|
||||
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteOffer(remoteAuth IceCredentials) bool {
|
||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||
log.Debugf("OnRemoteOffer from peer %s on status %s", conn.config.Key, conn.status.String())
|
||||
|
||||
select {
|
||||
case conn.remoteOffersCh <- remoteAuth:
|
||||
case conn.remoteOffersCh <- offer:
|
||||
return true
|
||||
default:
|
||||
log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
|
||||
@@ -533,11 +562,11 @@ func (conn *Conn) OnRemoteOffer(remoteAuth IceCredentials) bool {
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteAnswer(remoteAuth IceCredentials) bool {
|
||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
log.Debugf("OnRemoteAnswer from peer %s on status %s", conn.config.Key, conn.status.String())
|
||||
|
||||
select {
|
||||
case conn.remoteAnswerCh <- remoteAuth:
|
||||
case conn.remoteAnswerCh <- answer:
|
||||
return true
|
||||
default:
|
||||
// connection might not be ready yet to receive so we ignore the message
|
||||
|
||||
@@ -18,6 +18,7 @@ var connConf = ConnConfig{
|
||||
InterfaceBlackList: nil,
|
||||
Timeout: time.Second,
|
||||
ProxyConfig: proxy.Config{},
|
||||
LocalWgPort: 51820,
|
||||
}
|
||||
|
||||
func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
@@ -59,9 +60,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
accepted := conn.OnRemoteOffer(IceCredentials{
|
||||
UFrag: "test",
|
||||
Pwd: "test",
|
||||
accepted := conn.OnRemoteOffer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{
|
||||
UFrag: "test",
|
||||
Pwd: "test",
|
||||
},
|
||||
WgListenPort: 0,
|
||||
Version: "",
|
||||
})
|
||||
if accepted {
|
||||
wg.Done()
|
||||
@@ -89,9 +94,13 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
accepted := conn.OnRemoteAnswer(IceCredentials{
|
||||
UFrag: "test",
|
||||
Pwd: "test",
|
||||
accepted := conn.OnRemoteAnswer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{
|
||||
UFrag: "test",
|
||||
Pwd: "test",
|
||||
},
|
||||
WgListenPort: 0,
|
||||
Version: "",
|
||||
})
|
||||
if accepted {
|
||||
wg.Done()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
@@ -14,10 +13,14 @@ import (
|
||||
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
|
||||
type NoProxy struct {
|
||||
config Config
|
||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||
// It is used instead of the hardcoded 51820 port.
|
||||
RemoteWgListenPort int
|
||||
}
|
||||
|
||||
func NewNoProxy(config Config) *NoProxy {
|
||||
return &NoProxy{config: config}
|
||||
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
|
||||
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
|
||||
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
||||
}
|
||||
|
||||
func (p *NoProxy) Close() error {
|
||||
@@ -36,7 +39,7 @@ func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr.Port = iface.DefaultWgPort
|
||||
addr.Port = p.RemoteWgListenPort
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
|
||||
|
||||
285
client/internal/routemanager/client.go
Normal file
285
client/internal/routemanager/client.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
direct bool
|
||||
}
|
||||
|
||||
type routesUpdate struct {
|
||||
updateSerial uint64
|
||||
routes []*route.Route
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
statusRecorder *status.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[string]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
chosenRoute *route.Route
|
||||
network netip.Prefix
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
routes: make(map[string]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
network: network,
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func getClientNetworkID(input *route.Route) string {
|
||||
return input.NetID + "-" + input.Network.String()
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
routePeerStatuses := make(map[string]routerPeerStatus)
|
||||
for _, r := range c.routes {
|
||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
log.Debugf("couldn't fetch peer state: %v", err)
|
||||
continue
|
||||
}
|
||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||
connected: peerStatus.ConnStatus == peer.StatusConnected.String(),
|
||||
relayed: peerStatus.Relayed,
|
||||
direct: peerStatus.Direct,
|
||||
}
|
||||
}
|
||||
return routePeerStatuses
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||
var chosen string
|
||||
chosenScore := 0
|
||||
|
||||
currID := ""
|
||||
if c.chosenRoute != nil {
|
||||
currID = c.chosenRoute.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
tempScore := 0
|
||||
peerStatus, found := routePeerStatuses[r.ID]
|
||||
if !found || !peerStatus.connected {
|
||||
continue
|
||||
}
|
||||
if r.Metric < route.MaxMetric {
|
||||
metricDiff := route.MaxMetric - r.Metric
|
||||
tempScore = metricDiff * 10
|
||||
}
|
||||
if !peerStatus.relayed {
|
||||
tempScore++
|
||||
}
|
||||
if !peerStatus.direct {
|
||||
tempScore++
|
||||
}
|
||||
if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) {
|
||||
chosen = r.ID
|
||||
chosenScore = tempScore
|
||||
}
|
||||
}
|
||||
|
||||
if chosen == "" {
|
||||
var peers []string
|
||||
for _, r := range c.routes {
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers)
|
||||
} else if chosen != currID {
|
||||
log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore)
|
||||
}
|
||||
|
||||
return chosen
|
||||
}
|
||||
|
||||
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-closer:
|
||||
return
|
||||
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus == peer.StatusConnecting.String() {
|
||||
continue
|
||||
}
|
||||
peerStateUpdate <- struct{}{}
|
||||
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if !found {
|
||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus != peer.StatusConnected.String() {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||
c.network, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
|
||||
var err error
|
||||
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
if chosen == "" {
|
||||
err = c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.chosenRoute = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil {
|
||||
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
c.chosenRoute = c.routes[chosen]
|
||||
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
||||
if err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
||||
updateMap := make(map[string]*route.Route)
|
||||
|
||||
for _, r := range update.routes {
|
||||
updateMap[r.ID] = r
|
||||
}
|
||||
|
||||
for id, r := range c.routes {
|
||||
_, found := updateMap[id]
|
||||
if !found {
|
||||
close(c.routePeersNotifiers[r.Peer])
|
||||
delete(c.routePeersNotifiers, r.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
c.routes = updateMap
|
||||
}
|
||||
|
||||
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
|
||||
// All the processing related to the client network should be done here. Thread-safe.
|
||||
func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("received a new client network route update for %s", c.network)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
c.updateSerial = update.updateSerial
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
}
|
||||
}
|
||||
}
|
||||
75
client/internal/routemanager/common_linux_test.go
Normal file
75
client/internal/routemanager/common_linux_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package routemanager
|
||||
|
||||
var insertRuleTestCases = []struct {
|
||||
name string
|
||||
inputPair routerPair
|
||||
ipVersion string
|
||||
}{
|
||||
{
|
||||
name: "Insert Forwarding IPV4 Rule",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "100.100.100.1/32",
|
||||
destination: "100.100.200.0/24",
|
||||
masquerade: false,
|
||||
},
|
||||
ipVersion: ipv4,
|
||||
},
|
||||
{
|
||||
name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "100.100.100.1/32",
|
||||
destination: "100.100.200.0/24",
|
||||
masquerade: true,
|
||||
},
|
||||
ipVersion: ipv4,
|
||||
},
|
||||
{
|
||||
name: "Insert Forwarding IPV6 Rule",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "fc00::1/128",
|
||||
destination: "fc12::/64",
|
||||
masquerade: false,
|
||||
},
|
||||
ipVersion: ipv6,
|
||||
},
|
||||
{
|
||||
name: "Insert Forwarding And Nat IPV6 Rules",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "fc00::1/128",
|
||||
destination: "fc12::/64",
|
||||
masquerade: true,
|
||||
},
|
||||
ipVersion: ipv6,
|
||||
},
|
||||
}
|
||||
|
||||
var removeRuleTestCases = []struct {
|
||||
name string
|
||||
inputPair routerPair
|
||||
ipVersion string
|
||||
}{
|
||||
{
|
||||
name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "100.100.100.1/32",
|
||||
destination: "100.100.200.0/24",
|
||||
masquerade: true,
|
||||
},
|
||||
ipVersion: ipv4,
|
||||
},
|
||||
{
|
||||
name: "Remove Forwarding And Nat IPV6 Rules",
|
||||
inputPair: routerPair{
|
||||
ID: "zxa",
|
||||
source: "fc00::1/128",
|
||||
destination: "fc12::/64",
|
||||
masquerade: true,
|
||||
},
|
||||
ipVersion: ipv6,
|
||||
},
|
||||
}
|
||||
12
client/internal/routemanager/firewall.go
Normal file
12
client/internal/routemanager/firewall.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package routemanager
|
||||
|
||||
type firewallManager interface {
|
||||
// RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules
|
||||
RestoreOrCreateContainers() error
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair routerPair) error
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair routerPair) error
|
||||
// CleanRoutingRules cleans a firewall set of containers
|
||||
CleanRoutingRules()
|
||||
}
|
||||
55
client/internal/routemanager/firewall_linux.go
Normal file
55
client/internal/routemanager/firewall_linux.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
const (
|
||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
||||
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
|
||||
ipv6Nat = "netbird-rt-ipv6-nat"
|
||||
ipv4Nat = "netbird-rt-ipv4-nat"
|
||||
natFormat = "netbird-nat-%s"
|
||||
forwardingFormat = "netbird-fwd-%s"
|
||||
ipv6 = "ipv6"
|
||||
ipv4 = "ipv4"
|
||||
)
|
||||
|
||||
func genKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
|
||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||
ctx, cancel := context.WithCancel(parentCTX)
|
||||
|
||||
if isIptablesSupported() {
|
||||
log.Debugf("iptables is supported")
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
return &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("iptables is not supported, using nftables")
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
27
client/internal/routemanager/firewall_nonlinux.go
Normal file
27
client/internal/routemanager/firewall_nonlinux.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package routemanager
|
||||
|
||||
import "context"
|
||||
|
||||
type unimplementedFirewall struct{}
|
||||
|
||||
func (unimplementedFirewall) RestoreOrCreateContainers() error {
|
||||
return nil
|
||||
}
|
||||
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
|
||||
return nil
|
||||
}
|
||||
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (unimplementedFirewall) CleanRoutingRules() {
|
||||
return
|
||||
}
|
||||
|
||||
// NewFirewall returns an unimplemented Firewall manager
|
||||
func NewFirewall(parentCtx context.Context) firewallManager {
|
||||
return unimplementedFirewall{}
|
||||
}
|
||||
403
client/internal/routemanager/iptables_linux.go
Normal file
403
client/internal/routemanager/iptables_linux.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
_, err4 := exec.LookPath("iptables")
|
||||
_, err6 := exec.LookPath("ip6tables")
|
||||
return err4 == nil && err6 == nil
|
||||
}
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
iptablesFilterTable = "filter"
|
||||
iptablesNatTable = "nat"
|
||||
iptablesForwardChain = "FORWARD"
|
||||
iptablesPostRoutingChain = "POSTROUTING"
|
||||
iptablesRoutingNatChain = "NETBIRD-RT-NAT"
|
||||
iptablesRoutingForwardingChain = "NETBIRD-RT-FWD"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"}
|
||||
iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"}
|
||||
iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"}
|
||||
iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"}
|
||||
)
|
||||
|
||||
type iptablesManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
rules map[string]map[string][]string
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||
func (i *iptablesManager) CleanRoutingRules() {
|
||||
i.mux.Lock()
|
||||
defer i.mux.Unlock()
|
||||
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
log.Debug("flushing tables")
|
||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
log.Info("done cleaning up iptables rules")
|
||||
}
|
||||
|
||||
// RestoreOrCreateContainers restores existing iptables containers (chains and rules)
|
||||
// if they don't exist, we create them
|
||||
func (i *iptablesManager) RestoreOrCreateContainers() error {
|
||||
i.mux.Lock()
|
||||
defer i.mux.Unlock()
|
||||
|
||||
if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||
|
||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv4Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv6Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *iptablesManager) addJumpRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.rules[ipv4][ipv4Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4][ipv4Nat] = rule
|
||||
|
||||
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *iptablesManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4][ipv4Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func iptablesProtoToString(proto iptables.Protocol) string {
|
||||
if proto == iptables.ProtocolIPv6 {
|
||||
return ipv6
|
||||
}
|
||||
return ipv4
|
||||
}
|
||||
|
||||
// restoreRules restores existing NetBird rules
|
||||
func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error {
|
||||
ipVersion := iptablesProtoToString(iptablesClient.Proto())
|
||||
|
||||
if i.rules[ipVersion] == nil {
|
||||
i.rules[ipVersion] = make(map[string][]string)
|
||||
}
|
||||
table := iptablesFilterTable
|
||||
for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} {
|
||||
rules, err := iptablesClient.List(table, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ruleString := range rules {
|
||||
rule := strings.Fields(ruleString)
|
||||
id := getRuleRouteID(rule)
|
||||
if id != "" {
|
||||
i.rules[ipVersion][id] = rule[2:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
table = iptablesNatTable
|
||||
for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} {
|
||||
rules, err := iptablesClient.List(table, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ruleString := range rules {
|
||||
rule := strings.Fields(ruleString)
|
||||
id := getRuleRouteID(rule)
|
||||
if id != "" {
|
||||
i.rules[ipVersion][id] = rule[2:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createChain create NetBird chains
|
||||
func createChain(iptables *iptables.IPTables, table, newChain string) error {
|
||||
chains, err := iptables.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = iptables.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err)
|
||||
}
|
||||
|
||||
if table == iptablesNatTable {
|
||||
err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...)
|
||||
} else {
|
||||
err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification with comment identifier
|
||||
func genRuleSpec(jump, id, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
||||
}
|
||||
|
||||
// getRuleRouteID returns the rule ID if matches our prefix
|
||||
func getRuleRouteID(rule []string) string {
|
||||
for i, flag := range rule {
|
||||
if flag == "--comment" {
|
||||
id := rule[i+1]
|
||||
if strings.HasPrefix(id, "netbird-") {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
||||
i.mux.Lock()
|
||||
defer i.mux.Unlock()
|
||||
|
||||
var err error
|
||||
prefix := netip.MustParsePrefix(pair.source)
|
||||
ipVersion := ipv4
|
||||
iptablesClient := i.ipv4Client
|
||||
if prefix.Addr().Unmap().Is6() {
|
||||
iptablesClient = i.ipv6Client
|
||||
ipVersion = ipv6
|
||||
}
|
||||
|
||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
|
||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
||||
if found {
|
||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
delete(i.rules[ipVersion], forwardRuleKey)
|
||||
}
|
||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
|
||||
i.rules[ipVersion][forwardRuleKey] = forwardRule
|
||||
|
||||
if !pair.masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
natRuleKey := genKey(natFormat, pair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
|
||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
||||
if found {
|
||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
|
||||
}
|
||||
delete(i.rules[ipVersion], natRuleKey)
|
||||
}
|
||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
|
||||
}
|
||||
|
||||
i.rules[ipVersion][natRuleKey] = natRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
||||
i.mux.Lock()
|
||||
defer i.mux.Unlock()
|
||||
|
||||
var err error
|
||||
prefix := netip.MustParsePrefix(pair.source)
|
||||
ipVersion := ipv4
|
||||
iptablesClient := i.ipv4Client
|
||||
if prefix.Addr().Unmap().Is6() {
|
||||
iptablesClient = i.ipv6Client
|
||||
ipVersion = ipv6
|
||||
}
|
||||
|
||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
||||
if found {
|
||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
}
|
||||
delete(i.rules[ipVersion], forwardRuleKey)
|
||||
|
||||
if !pair.masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
natRuleKey := genKey(natFormat, pair.ID)
|
||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
||||
if found {
|
||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
}
|
||||
delete(i.rules[ipVersion], natRuleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
247
client/internal/routemanager/iptables_linux_test.go
Normal file
247
client/internal/routemanager/iptables_linux_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
|
||||
|
||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||
|
||||
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := routerPair{
|
||||
ID: "abc",
|
||||
source: "100.100.100.1/32",
|
||||
destination: "100.100.100.0/24",
|
||||
masquerade: true,
|
||||
}
|
||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
pair = routerPair{
|
||||
ID: "abc",
|
||||
source: "fc00::1/128",
|
||||
destination: "fc11::/64",
|
||||
masquerade: true,
|
||||
}
|
||||
|
||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
delete(manager.rules, ipv4)
|
||||
delete(manager.rules, ipv6)
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4")
|
||||
|
||||
foundRule, found := manager.rules[ipv4][forward4RuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the map")
|
||||
require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
foundRule, found = manager.rules[ipv4][nat4RuleKey]
|
||||
require.True(t, found, "nat rule should exist in the map")
|
||||
require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match")
|
||||
|
||||
require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6")
|
||||
|
||||
foundRule, found = manager.rules[ipv6][forward6RuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the map")
|
||||
require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match")
|
||||
|
||||
foundRule, found = manager.rules[ipv6][nat6RuleKey]
|
||||
require.True(t, found, "nat rule should exist in the map")
|
||||
require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match")
|
||||
}
|
||||
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range insertRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
iptablesClient := ipv4Client
|
||||
if testCase.ipVersion == ipv6 {
|
||||
iptablesClient = ipv6Client
|
||||
}
|
||||
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if testCase.inputPair.masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range removeRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
iptablesClient := ipv4Client
|
||||
if testCase.ipVersion == ipv6 {
|
||||
iptablesClient = ipv6Client
|
||||
}
|
||||
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||
|
||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||
|
||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
delete(manager.rules, ipv4)
|
||||
delete(manager.rules, ipv6)
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.inputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
181
client/internal/routemanager/manager.go
Normal file
181
client/internal/routemanager/manager.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRoutes map[string]*route.Route
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *status.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRoutes: make(map[string]*route.Route),
|
||||
serverRouter: &serverRouter{
|
||||
routes: make(map[string]*route.Route),
|
||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||
firewall: NewFirewall(ctx),
|
||||
},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.serverRoutes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.serverRoutes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.serverRoutes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.serverRoutes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.serverRoutes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.serverRoutes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
// only linux is supported for now
|
||||
if newRoute.Peer == m.pubKey {
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
} else {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
system.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
clientNetworkID := getClientNetworkID(newRoute)
|
||||
newClientRoutesIDMap[clientNetworkID] = append(newClientRoutesIDMap[clientNetworkID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.updateServerRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
370
client/internal/routemanager/manager_test.go
Normal file
370
client/internal/routemanager/manager_test.go
Normal file
@@ -0,0 +1,370 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
||||
// if linux host, should have one for server in map
|
||||
// we should have 2 client manager
|
||||
// 2 ranges in our routing table
|
||||
|
||||
const localPeerKey = "local"
|
||||
const remotePeerKey1 = "remote1"
|
||||
const remotePeerKey2 = "remote1"
|
||||
|
||||
func TestManagerUpdateRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
shouldCheckServerRoutes bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
}{
|
||||
{
|
||||
name: "Should create 2 client networks",
|
||||
inputInitRoutes: []*route.Route{},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
},
|
||||
{
|
||||
name: "Should Create 2 Server Routes",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.252.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("8.8.8.9/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS == "linux",
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client And Server",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.9.9/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS == "linux",
|
||||
serverRoutesExpected: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 HA Route and 1 Standalone",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.20.0/24"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("8.8.20.0/24"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "c",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.9.9/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 2,
|
||||
},
|
||||
{
|
||||
name: "No Small Client Route Should Be Added",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "No Server Routes Should Be Added To Non Linux",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS != "linux",
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Update Route to HA",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Remove Client Routes",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Remove All Routes",
|
||||
inputInitRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputRoutes: []*route.Route{},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: true,
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
statusRecorder := status.NewRecorder()
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
||||
defer routeManager.Stop()
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
|
||||
if testCase.shouldCheckServerRoutes {
|
||||
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
27
client/internal/routemanager/mock.go
Normal file
27
client/internal/routemanager/mock.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
if m.UpdateRoutesFunc != nil {
|
||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||
}
|
||||
return fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
}
|
||||
|
||||
// Stop mock implementation of Stop from Manager interface
|
||||
func (m *MockManager) Stop() {
|
||||
if m.StopFunc != nil {
|
||||
m.StopFunc()
|
||||
}
|
||||
}
|
||||
386
client/internal/routemanager/nftables_linux.go
Normal file
386
client/internal/routemanager/nftables_linux.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
//
|
||||
const (
|
||||
nftablesTable = "netbird-rt"
|
||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to create nftable rules
|
||||
const (
|
||||
ipv4Len = 4
|
||||
ipv4SrcOffset = 12
|
||||
ipv4DestOffset = 16
|
||||
ipv6Len = 16
|
||||
ipv6SrcOffset = 8
|
||||
ipv6DestOffset = 24
|
||||
exprDirectionSource = "source"
|
||||
exprDirectionDestination = "destination"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
|
||||
|
||||
exprAllowRelatedEstablished = []expr.Any{
|
||||
&expr.Ct{
|
||||
Register: 1,
|
||||
SourceRegister: false,
|
||||
Key: 0,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
|
||||
Xor: zeroXor,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type nftablesManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
chains map[string]map[string]*nftables.Chain
|
||||
rules map[string]*nftables.Rule
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing nftables rules from the system
|
||||
func (n *nftablesManager) CleanRoutingRules() {
|
||||
n.mux.Lock()
|
||||
defer n.mux.Unlock()
|
||||
log.Debug("flushing tables")
|
||||
if n.tableIPv4 != nil && n.tableIPv6 != nil {
|
||||
n.conn.FlushTable(n.tableIPv6)
|
||||
n.conn.FlushTable(n.tableIPv4)
|
||||
}
|
||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||
}
|
||||
|
||||
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
|
||||
// if they don't exist, we create them
|
||||
func (n *nftablesManager) RestoreOrCreateContainers() error {
|
||||
n.mux.Lock()
|
||||
defer n.mux.Unlock()
|
||||
|
||||
if n.tableIPv6 != nil && n.tableIPv4 != nil {
|
||||
log.Debugf("nftables: containers already restored, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
tables, err := n.conn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == nftablesTable {
|
||||
if table.Family == nftables.TableFamilyIPv4 {
|
||||
n.tableIPv4 = table
|
||||
continue
|
||||
}
|
||||
n.tableIPv6 = table
|
||||
}
|
||||
}
|
||||
|
||||
if n.tableIPv4 == nil {
|
||||
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
|
||||
Name: nftablesTable,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
})
|
||||
}
|
||||
|
||||
if n.tableIPv6 == nil {
|
||||
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
|
||||
Name: nftablesTable,
|
||||
Family: nftables.TableFamilyIPv6,
|
||||
})
|
||||
}
|
||||
|
||||
chains, err := n.conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list chains: %v", err)
|
||||
}
|
||||
|
||||
n.chains[ipv4] = make(map[string]*nftables.Chain)
|
||||
n.chains[ipv6] = make(map[string]*nftables.Chain)
|
||||
|
||||
for _, chain := range chains {
|
||||
switch {
|
||||
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
|
||||
n.chains[ipv4][chain.Name] = chain
|
||||
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
|
||||
n.chains[ipv6][chain.Name] = chain
|
||||
}
|
||||
}
|
||||
|
||||
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
|
||||
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
||||
Name: nftablesRoutingForwardingChain,
|
||||
Table: n.tableIPv4,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityNATDest + 1,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
}
|
||||
|
||||
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
|
||||
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
||||
Name: nftablesRoutingNatChain,
|
||||
Table: n.tableIPv4,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
}
|
||||
|
||||
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
|
||||
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
||||
Name: nftablesRoutingForwardingChain,
|
||||
Table: n.tableIPv6,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityNATDest + 1,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
}
|
||||
|
||||
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
|
||||
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
||||
Name: nftablesRoutingNatChain,
|
||||
Table: n.tableIPv6,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
}
|
||||
|
||||
err = n.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n.checkOrCreateDefaultForwardingRules()
|
||||
err = n.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (n *nftablesManager) refreshRulesMap() error {
|
||||
for _, registeredChains := range n.chains {
|
||||
for _, chain := range registeredChains {
|
||||
rules, err := n.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
n.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||
if !foundIPv4 {
|
||||
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
|
||||
Table: n.tableIPv4,
|
||||
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
||||
Exprs: exprAllowRelatedEstablished,
|
||||
UserData: []byte(ipv4Forwarding),
|
||||
})
|
||||
}
|
||||
|
||||
_, foundIPv6 := n.rules[ipv6Forwarding]
|
||||
if !foundIPv6 {
|
||||
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
|
||||
Table: n.tableIPv6,
|
||||
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
||||
Exprs: exprAllowRelatedEstablished,
|
||||
UserData: []byte(ipv6Forwarding),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
||||
n.mux.Lock()
|
||||
defer n.mux.Unlock()
|
||||
|
||||
prefix := netip.MustParsePrefix(pair.source)
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
||||
if prefix.Addr().Unmap().Is4() {
|
||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
||||
Table: n.tableIPv4,
|
||||
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(fwdKey),
|
||||
})
|
||||
} else {
|
||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
||||
Table: n.tableIPv6,
|
||||
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(fwdKey),
|
||||
})
|
||||
}
|
||||
|
||||
if pair.masquerade {
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||
natKey := genKey(natFormat, pair.ID)
|
||||
|
||||
if prefix.Addr().Unmap().Is4() {
|
||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
||||
Table: n.tableIPv4,
|
||||
Chain: n.chains[ipv4][nftablesRoutingNatChain],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natKey),
|
||||
})
|
||||
} else {
|
||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
||||
Table: n.tableIPv6,
|
||||
Chain: n.chains[ipv6][nftablesRoutingNatChain],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natKey),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
err := n.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
||||
n.mux.Lock()
|
||||
defer n.mux.Unlock()
|
||||
|
||||
err := n.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
||||
natKey := genKey(natFormat, pair.ID)
|
||||
fwdRule, found := n.rules[fwdKey]
|
||||
if found {
|
||||
err = n.conn.DelRule(fwdRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
|
||||
delete(n.rules, fwdKey)
|
||||
}
|
||||
natRule, found := n.rules[natKey]
|
||||
if found {
|
||||
err = n.conn.DelRule(natRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removing nat rule for %s", pair.destination)
|
||||
delete(n.rules, natKey)
|
||||
}
|
||||
err = n.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPayloadDirectives get expression directives based on ip version and direction
|
||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||
switch {
|
||||
case direction == exprDirectionSource && isIPv4:
|
||||
return ipv4SrcOffset, ipv4Len, zeroXor
|
||||
case direction == exprDirectionDestination && isIPv4:
|
||||
return ipv4DestOffset, ipv4Len, zeroXor
|
||||
case direction == exprDirectionSource && isIPv6:
|
||||
return ipv6SrcOffset, ipv6Len, zeroXor6
|
||||
case direction == exprDirectionDestination && isIPv6:
|
||||
return ipv6DestOffset, ipv6Len, zeroXor6
|
||||
default:
|
||||
panic("no matched payload directive")
|
||||
}
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: packetLen,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: packetLen,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
270
client/internal/routemanager/nftables_linux_test.go
Normal file
270
client/internal/routemanager/nftables_linux_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
||||
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
|
||||
|
||||
pair := routerPair{
|
||||
ID: "abc",
|
||||
source: "100.100.100.1/32",
|
||||
destination: "100.100.100.0/24",
|
||||
masquerade: true,
|
||||
}
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
||||
|
||||
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.tableIPv4,
|
||||
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
|
||||
Exprs: forward4Exp,
|
||||
UserData: []byte(forward4RuleKey),
|
||||
})
|
||||
|
||||
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||
|
||||
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.tableIPv4,
|
||||
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
|
||||
Exprs: nat4Exp,
|
||||
UserData: []byte(nat4RuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
pair = routerPair{
|
||||
ID: "xyz",
|
||||
source: "fc00::1/128",
|
||||
destination: "fc11::/64",
|
||||
masquerade: true,
|
||||
}
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
|
||||
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
|
||||
|
||||
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.tableIPv6,
|
||||
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
|
||||
Exprs: forward6Exp,
|
||||
UserData: []byte(forward6RuleKey),
|
||||
})
|
||||
|
||||
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||
|
||||
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.tableIPv6,
|
||||
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
|
||||
Exprs: nat6Exp,
|
||||
UserData: []byte(nat6RuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
manager.tableIPv4 = nil
|
||||
manager.tableIPv6 = nil
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
||||
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
|
||||
|
||||
foundRule, found := manager.rules[forward4RuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the map")
|
||||
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
|
||||
|
||||
foundRule, found = manager.rules[nat4RuleKey]
|
||||
require.True(t, found, "nat rule should exist in the map")
|
||||
// match len of output as nftables client doesn't return expressions with masquerade expression
|
||||
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
|
||||
|
||||
foundRule, found = manager.rules[forward6RuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the map")
|
||||
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
|
||||
|
||||
foundRule, found = manager.rules[nat6RuleKey]
|
||||
require.True(t, found, "nat rule should exist in the map")
|
||||
// match len of output as nftables client doesn't return expressions with masquerade expression
|
||||
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
|
||||
}
|
||||
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range insertRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
||||
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
||||
testingExpression := append(sourceExp, destExp...)
|
||||
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, registeredChains := range manager.chains {
|
||||
for _, chain := range registeredChains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.inputPair.masquerade {
|
||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||
found := 0
|
||||
for _, registeredChains := range manager.chains {
|
||||
for _, chain := range registeredChains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range removeRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
table := manager.tableIPv4
|
||||
if testCase.ipVersion == ipv6 {
|
||||
table = manager.tableIPv6
|
||||
}
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
||||
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
manager.tableIPv4 = nil
|
||||
manager.tableIPv6 = nil
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.inputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, registeredChains := range manager.chains {
|
||||
for _, chain := range registeredChains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
67
client/internal/routemanager/server.go
Normal file
67
client/internal/routemanager/server.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
routes map[string]*route.Route
|
||||
// best effort to keep net forward configuration as it was
|
||||
netForwardHistoryEnabled bool
|
||||
mux sync.Mutex
|
||||
firewall firewallManager
|
||||
}
|
||||
|
||||
type routerPair struct {
|
||||
ID string
|
||||
source string
|
||||
destination string
|
||||
masquerade bool
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return routerPair{
|
||||
ID: route.ID,
|
||||
source: parsed.String(),
|
||||
destination: route.Network.Masked().String(),
|
||||
masquerade: route.Masquerade,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.serverRouter.routes, route.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.serverRouter.routes[route.ID] = route
|
||||
return nil
|
||||
}
|
||||
}
|
||||
55
client/internal/routemanager/systemops.go
Normal file
55
client/internal/routemanager/systemops.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
var errRouteNotFound = fmt.Errorf("route not found")
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
if err != nil && err != errRouteNotFound {
|
||||
return err
|
||||
}
|
||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
||||
if err != nil && err != errRouteNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
if prefixGateway != nil && !prefixGateway.Equal(gateway) {
|
||||
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway)
|
||||
return nil
|
||||
}
|
||||
return addToRouteTable(prefix, addr)
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
addrIP := net.ParseIP(addr)
|
||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if prefixGateway != nil && !prefixGateway.Equal(addrIP) {
|
||||
log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP)
|
||||
return nil
|
||||
}
|
||||
return removeFromRouteTable(prefix)
|
||||
}
|
||||
|
||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice())
|
||||
if err != nil {
|
||||
log.Errorf("getting routes returned an error: %v", err)
|
||||
return nil, errRouteNotFound
|
||||
}
|
||||
|
||||
return localGatewayAddress, nil
|
||||
}
|
||||
73
client/internal/routemanager/systemops_linux.go
Normal file
73
client/internal/routemanager/systemops_linux.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addrMask := "/32"
|
||||
if prefix.Addr().Unmap().Is6() {
|
||||
addrMask = "/128"
|
||||
}
|
||||
|
||||
ip, _, err := net.ParseCIDR(addr + addrMask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Dst: ipNet,
|
||||
Gw: ip,
|
||||
}
|
||||
|
||||
err = netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
err = netlink.RouteDel(route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
err := ioutil.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
|
||||
return err
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
out, err := ioutil.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
// todo
|
||||
panic(err)
|
||||
}
|
||||
return string(out) == "1"
|
||||
}
|
||||
41
client/internal/routemanager/systemops_nonlinux.go
Normal file
41
client/internal/routemanager/systemops_nonlinux.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix) error {
|
||||
cmd := exec.Command("route", "delete", prefix.String())
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
log.Infof("check netforwad history is not implemented on %s", runtime.GOOS)
|
||||
return false
|
||||
}
|
||||
68
client/internal/routemanager/systemops_test.go
Normal file
68
client/internal/routemanager/systemops_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
shouldRouteToWireguard bool
|
||||
shouldBeRemoved bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove Route",
|
||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
if testCase.shouldRouteToWireguard {
|
||||
require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
} else {
|
||||
require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||
}
|
||||
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -47,17 +47,19 @@ type FullStatus struct {
|
||||
|
||||
// Status holds a state of peers, signal and management connections
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]PeerState
|
||||
signal SignalState
|
||||
management ManagementState
|
||||
localPeer LocalPeerState
|
||||
mux sync.Mutex
|
||||
peers map[string]PeerState
|
||||
changeNotify map[string]chan struct{}
|
||||
signal SignalState
|
||||
management ManagementState
|
||||
localPeer LocalPeerState
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder() *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]PeerState),
|
||||
peers: make(map[string]PeerState),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +76,18 @@ func (d *Status) AddPeer(peerPubKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (PeerState, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return PeerState{}, errors.New("peer not found")
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
@@ -113,9 +127,27 @@ func (d *Status) UpdatePeerState(receivedState PeerState) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
ch, found := d.changeNotify[peer]
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -19,6 +19,21 @@ func TestAddPeer(t *testing.T) {
|
||||
assert.Error(t, err, "should return error on duplicate")
|
||||
}
|
||||
|
||||
func TestGetPeer(t *testing.T) {
|
||||
key := "abc"
|
||||
status := NewRecorder()
|
||||
err := status.AddPeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
peerStatus, err := status.GetPeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error on getting peer")
|
||||
|
||||
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
||||
|
||||
_, err = status.GetPeer("non_existing_key")
|
||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||
}
|
||||
|
||||
func TestUpdatePeerState(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
@@ -39,6 +54,31 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
status := NewRecorder()
|
||||
peerState := PeerState{
|
||||
PubKey: key,
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
ch := status.GetPeerStateChangeNotifier(key)
|
||||
assert.NotNil(t, ch, "channel shouldn't be nil")
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Errorf("channel wasn't closed after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
key := "abc"
|
||||
status := NewRecorder()
|
||||
|
||||
18
go.mod
18
go.mod
@@ -11,7 +11,7 @@ require (
|
||||
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 //keep this version otherwise wiretrustee up command breaks
|
||||
github.com/onsi/ginkgo v1.16.5
|
||||
github.com/onsi/gomega v1.18.1
|
||||
github.com/pion/ice/v2 v2.1.17
|
||||
github.com/pion/ice/v2 v2.2.7
|
||||
github.com/rs/cors v1.8.0
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/spf13/cobra v1.3.0
|
||||
@@ -30,16 +30,19 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.1.4
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/coreos/go-iptables v0.6.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v2 v2.3.1
|
||||
github.com/getlantern/systray v1.2.1
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/libp2p/go-netroute v0.2.0
|
||||
github.com/magiconair/properties v1.8.5
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/stretchr/testify v1.7.1
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c
|
||||
golang.org/x/net v0.0.0-20220630215102-69896b714898
|
||||
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
|
||||
)
|
||||
|
||||
@@ -67,6 +70,7 @@ require (
|
||||
github.com/godbus/dbus/v5 v5.0.4 // indirect
|
||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||
github.com/google/go-cmp v0.5.7 // indirect
|
||||
github.com/google/gopacket v1.1.19 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
@@ -76,13 +80,13 @@ require (
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.1.2 // indirect
|
||||
github.com/pion/dtls/v2 v2.1.5 // indirect
|
||||
github.com/pion/logging v0.2.2 // indirect
|
||||
github.com/pion/mdns v0.0.5 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/stun v0.3.5 // indirect
|
||||
github.com/pion/transport v0.13.0 // indirect
|
||||
github.com/pion/turn/v2 v2.0.7 // indirect
|
||||
github.com/pion/transport v0.13.1 // indirect
|
||||
github.com/pion/turn/v2 v2.0.8 // indirect
|
||||
github.com/pion/udp v0.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_golang v1.12.2 // indirect
|
||||
@@ -113,6 +117,4 @@ require (
|
||||
k8s.io/apimachinery v0.23.5 // indirect
|
||||
)
|
||||
|
||||
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
|
||||
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220901161712-56a6ec08182e
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84
|
||||
|
||||
39
go.sum
39
go.sum
@@ -115,6 +115,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
|
||||
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U=
|
||||
github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk=
|
||||
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
@@ -283,10 +285,14 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8
|
||||
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||
@@ -399,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
|
||||
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
|
||||
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
|
||||
github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w=
|
||||
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
|
||||
@@ -464,8 +472,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||
github.com/netbirdio/service v0.0.0-20220901161712-56a6ec08182e h1:T7x1EzbvEiuvRtfNxLmbw5z2cwdXuFx+plt2lzY3nPY=
|
||||
github.com/netbirdio/service v0.0.0-20220901161712-56a6ec08182e/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 h1:u8kpzR9ld1uAeH/BAXsS0SfcnhooLWeO7UgHSBVPD9I=
|
||||
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -497,8 +505,10 @@ github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTK
|
||||
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
|
||||
github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ=
|
||||
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
|
||||
github.com/pion/dtls/v2 v2.1.2 h1:22Q1Jk9L++Yo7BIf9130MonNPfPVb+YgdYLeyQotuAA=
|
||||
github.com/pion/dtls/v2 v2.1.2/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
|
||||
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
|
||||
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
|
||||
github.com/pion/ice/v2 v2.2.7 h1:kG9tux3WdYUSqqqnf+O5zKlpy41PdlvLUBlYJeV2emQ=
|
||||
github.com/pion/ice/v2 v2.2.7/go.mod h1:Ckj7cWZ717rtU01YoDQA9ntGWCk95D42uVZ8sI0EL+8=
|
||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
|
||||
@@ -508,10 +518,11 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB
|
||||
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
|
||||
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
|
||||
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
|
||||
github.com/pion/transport v0.13.0 h1:KWTA5ZrQogizzYwPEciGtHPLwpAjE91FgXnyu+Hv2uY=
|
||||
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
|
||||
github.com/pion/turn/v2 v2.0.7 h1:SZhc00WDovK6czaN1RSiHqbwANtIO6wfZQsU0m0KNE8=
|
||||
github.com/pion/turn/v2 v2.0.7/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
|
||||
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
|
||||
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
|
||||
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
|
||||
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
|
||||
github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o=
|
||||
github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
@@ -616,8 +627,6 @@ github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJ
|
||||
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb h1:CU1/+CEeCPvYXgfAyqTJXSQSf6hW3wsWM6Dfz6HkHEQ=
|
||||
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb/go.mod h1:XT1Nrb4OxbVFPffbQMbq4PaeEkpRLVzdphh3fjrw7DY=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -654,7 +663,7 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
|
||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -748,6 +757,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8=
|
||||
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
@@ -763,8 +773,10 @@ golang.org/x/net v0.0.0-20211208012354-db4efeb81f4b/go.mod h1:9nx3DQGgdP8bBQD5qx
|
||||
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
|
||||
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -870,6 +882,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -896,6 +909,8 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
||||
@@ -9,6 +9,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetName returns the interface name
|
||||
func (w *WGIface) GetName() string {
|
||||
return w.Name
|
||||
}
|
||||
|
||||
// GetAddress returns the interface address
|
||||
func (w *WGIface) GetAddress() WGAddress {
|
||||
return w.Address
|
||||
}
|
||||
|
||||
// configureDevice configures the wireguard device
|
||||
func (w *WGIface) configureDevice(config wgtypes.Config) error {
|
||||
wg, err := wgctrl.New()
|
||||
@@ -112,6 +122,114 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
UpdateOnly: true,
|
||||
ReplaceAllowedIPs: false,
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
err = w.configureDevice(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingPeer, err := getPeer(w.Name, peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newAllowedIPs := existingPeer.AllowedIPs
|
||||
|
||||
for i, existingAllowedIP := range existingPeer.AllowedIPs {
|
||||
if existingAllowedIP.String() == ipNet.String() {
|
||||
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
UpdateOnly: true,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: newAllowedIPs,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
err = w.configureDevice(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return wgtypes.Peer{}, err
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
return wgtypes.Peer{}, err
|
||||
}
|
||||
for _, peer := range wgDevice.Peers {
|
||||
if peer.PublicKey.String() == peerPubKey {
|
||||
return peer, nil
|
||||
}
|
||||
}
|
||||
return wgtypes.Peer{}, fmt.Errorf("peer not found")
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
||||
func WireguardModExists() bool {
|
||||
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||
func WireguardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,48 +1,29 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"os"
|
||||
)
|
||||
|
||||
type NativeLink struct {
|
||||
Link *netlink.Link
|
||||
}
|
||||
|
||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
||||
func WireguardModExists() bool {
|
||||
link := newWGLink("mustnotexist")
|
||||
|
||||
// We willingly try to create a device with an invalid
|
||||
// MTU here as the validation of the MTU will be performed after
|
||||
// the validation of the link kind and hence allows us to check
|
||||
// for the existance of the wireguard module without actually
|
||||
// creating a link.
|
||||
//
|
||||
// As a side-effect, this will also let the kernel lazy-load
|
||||
// the wireguard module.
|
||||
link.attrs.MTU = math.MaxInt
|
||||
|
||||
err := netlink.LinkAdd(link)
|
||||
|
||||
return errors.Is(err, syscall.EINVAL)
|
||||
}
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func (w *WGIface) Create() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if WireguardModExists() {
|
||||
if WireguardModuleIsLoaded() {
|
||||
log.Info("using kernel WireGuard")
|
||||
return w.createWithKernel()
|
||||
} else {
|
||||
if !tunModuleIsLoaded() {
|
||||
return fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
log.Info("using userspace WireGuard")
|
||||
return w.createWithUserspace()
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peer, err := getPeer(ifaceName, peerPubKey, t)
|
||||
peer, err := getPeer(ifaceName, peerPubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -289,7 +289,7 @@ func Test_RemovePeer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = getPeer(ifaceName, peerPubKey, t)
|
||||
_, err = getPeer(ifaceName, peerPubKey)
|
||||
if err.Error() != "peer not found" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -378,7 +378,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
t.Fatalf("waiting for peer handshake timeout after %s", timeout.String())
|
||||
default:
|
||||
}
|
||||
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String(), t)
|
||||
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String())
|
||||
if gpErr != nil {
|
||||
t.Fatal(gpErr)
|
||||
}
|
||||
@@ -389,28 +389,3 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getPeer(ifaceName, peerPubKey string, t *testing.T) (wgtypes.Peer, error) {
|
||||
emptyPeer := wgtypes.Peer{}
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return emptyPeer, err
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
return emptyPeer, err
|
||||
}
|
||||
for _, peer := range wgDevice.Peers {
|
||||
if peer.PublicKey.String() == peerPubKey {
|
||||
return peer, nil
|
||||
}
|
||||
}
|
||||
return emptyPeer, fmt.Errorf("peer not found")
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
return w.assignAddr(luid)
|
||||
}
|
||||
|
||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
||||
func WireguardModExists() bool {
|
||||
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||
func WireguardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
349
iface/module_linux.go
Normal file
349
iface/module_linux.go
Normal file
@@ -0,0 +1,349 @@
|
||||
// Package iface provides wireguard network interface creation and management
|
||||
package iface
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Holds logic to check existence of kernel modules used by wireguard interfaces
|
||||
// Copied from https://github.com/paultag/go-modprobe and
|
||||
// https://github.com/pmorjan/kmod
|
||||
|
||||
type status int
|
||||
|
||||
const (
|
||||
defaultModuleDir = "/lib/modules"
|
||||
unknown status = iota
|
||||
unloaded
|
||||
unloading
|
||||
loading
|
||||
live
|
||||
inuse
|
||||
)
|
||||
|
||||
type module struct {
|
||||
name string
|
||||
path string
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrModuleNotFound is the error resulting if a module can't be found.
|
||||
ErrModuleNotFound = errors.New("module not found")
|
||||
moduleLibDir = defaultModuleDir
|
||||
// get the root directory for the kernel modules. If this line panics,
|
||||
// it's because getModuleRoot has failed to get the uname of the running
|
||||
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
|
||||
moduleRoot = getModuleRoot()
|
||||
)
|
||||
|
||||
// Get the module root (/lib/modules/$(uname -r)/)
|
||||
func getModuleRoot() string {
|
||||
uname := unix.Utsname{}
|
||||
if err := unix.Uname(&uname); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for ; uname.Release[i] != 0; i++ {
|
||||
}
|
||||
|
||||
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
|
||||
}
|
||||
|
||||
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
|
||||
func tunModuleIsLoaded() bool {
|
||||
_, err := os.Stat("/dev/net/tun")
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
|
||||
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
|
||||
|
||||
tunLoaded, err := tryToLoadModule("tun")
|
||||
if err != nil {
|
||||
log.Errorf("unable to find or load tun module, got error: %v", err)
|
||||
}
|
||||
return tunLoaded
|
||||
}
|
||||
|
||||
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||
func WireguardModuleIsLoaded() bool {
|
||||
if canCreateFakeWireguardInterface() {
|
||||
return true
|
||||
}
|
||||
|
||||
loaded, err := tryToLoadModule("wireguard")
|
||||
if err != nil {
|
||||
log.Info(err)
|
||||
return false
|
||||
}
|
||||
|
||||
return loaded
|
||||
}
|
||||
|
||||
func canCreateFakeWireguardInterface() bool {
|
||||
link := newWGLink("mustnotexist")
|
||||
|
||||
// We willingly try to create a device with an invalid
|
||||
// MTU here as the validation of the MTU will be performed after
|
||||
// the validation of the link kind and hence allows us to check
|
||||
// for the existance of the wireguard module without actually
|
||||
// creating a link.
|
||||
//
|
||||
// As a side-effect, this will also let the kernel lazy-load
|
||||
// the wireguard module.
|
||||
link.attrs.MTU = math.MaxInt
|
||||
|
||||
err := netlink.LinkAdd(link)
|
||||
|
||||
return errors.Is(err, syscall.EINVAL)
|
||||
}
|
||||
|
||||
func tryToLoadModule(moduleName string) (bool, error) {
|
||||
if isModuleEnabled(moduleName) {
|
||||
return true, nil
|
||||
}
|
||||
modulePath, err := getModulePath(moduleName)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
|
||||
}
|
||||
if modulePath == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Infof("trying to load %s module", moduleName)
|
||||
|
||||
err = loadModuleWithDependencies(moduleName, modulePath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isModuleEnabled(name string) bool {
|
||||
builtin, builtinErr := isBuiltinModule(name)
|
||||
state, statusErr := moduleStatus(name)
|
||||
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
|
||||
}
|
||||
|
||||
func getModulePath(name string) (string, error) {
|
||||
var foundPath string
|
||||
skipRemainingDirs := false
|
||||
|
||||
err := filepath.WalkDir(
|
||||
moduleRoot,
|
||||
func(path string, info fs.DirEntry, err error) error {
|
||||
if skipRemainingDirs {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if err != nil {
|
||||
// skip broken files
|
||||
return nil
|
||||
}
|
||||
|
||||
if !info.Type().IsRegular() {
|
||||
return nil
|
||||
}
|
||||
|
||||
nameFromPath := pathToName(path)
|
||||
if nameFromPath == name {
|
||||
foundPath = path
|
||||
skipRemainingDirs = true
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return foundPath, nil
|
||||
}
|
||||
|
||||
func pathToName(s string) string {
|
||||
s = filepath.Base(s)
|
||||
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
|
||||
s = strings.TrimSuffix(s, ext)
|
||||
}
|
||||
return cleanName(s)
|
||||
}
|
||||
|
||||
func cleanName(s string) string {
|
||||
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
|
||||
}
|
||||
|
||||
func isBuiltinModule(name string) (bool, error) {
|
||||
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing modules.builtin file, %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var found bool
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if pathToName(line) == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return found, nil
|
||||
}
|
||||
|
||||
// /proc/modules
|
||||
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
|
||||
// macvlan 28672 1 macvtap, Live 0x0000000000000000
|
||||
func moduleStatus(name string) (status, error) {
|
||||
state := unknown
|
||||
f, err := os.Open("/proc/modules")
|
||||
if err != nil {
|
||||
return state, err
|
||||
}
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing /proc/modules file, %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
state = unloaded
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if fields[0] == name {
|
||||
if fields[2] != "0" {
|
||||
state = inuse
|
||||
break
|
||||
}
|
||||
switch fields[4] {
|
||||
case "Live":
|
||||
state = live
|
||||
case "Loading":
|
||||
state = loading
|
||||
case "Unloading":
|
||||
state = unloading
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return state, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func loadModuleWithDependencies(name, path string) error {
|
||||
deps, err := getModuleDependencies(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't load list of module %s dependecies", name)
|
||||
}
|
||||
for _, dep := range deps {
|
||||
err = loadModule(dep.name, dep.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name)
|
||||
}
|
||||
}
|
||||
return loadModule(name, path)
|
||||
}
|
||||
|
||||
func loadModule(name, path string) error {
|
||||
state, err := moduleStatus(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if state >= loading {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing %s file, %v", path, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// first try finit_module(2), then init_module(2)
|
||||
err = unix.FinitModule(int(f.Fd()), "", 0)
|
||||
if errors.Is(err, unix.ENOSYS) {
|
||||
buf, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return unix.InitModule(buf, "")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// getModuleDependencies returns a module dependencies
|
||||
func getModuleDependencies(name string) ([]module, error) {
|
||||
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing modules.dep file, %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var deps []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Fields(line)
|
||||
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
|
||||
deps = fields
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(deps) == 0 {
|
||||
return nil, ErrModuleNotFound
|
||||
}
|
||||
deps[0] = strings.TrimSuffix(deps[0], ":")
|
||||
|
||||
var modules []module
|
||||
for _, v := range deps {
|
||||
if pathToName(v) != name {
|
||||
modules = append(modules, module{
|
||||
name: pathToName(v),
|
||||
path: filepath.Join(moduleRoot, v),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return modules, nil
|
||||
}
|
||||
221
iface/module_linux_test.go
Normal file
221
iface/module_linux_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetModuleDependencies(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
module string
|
||||
expected []module
|
||||
}{
|
||||
{
|
||||
name: "Get Single Dependency",
|
||||
module: "bar",
|
||||
expected: []module{
|
||||
{name: "foo", path: "kernel/a/foo.ko"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get Multiple Dependencies",
|
||||
module: "baz",
|
||||
expected: []module{
|
||||
{name: "foo", path: "kernel/a/foo.ko"},
|
||||
{name: "bar", path: "kernel/a/bar.ko"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get No Dependencies",
|
||||
module: "foo",
|
||||
expected: []module{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
defer resetGlobals()
|
||||
_, _ = createFiles(t)
|
||||
modules, err := getModuleDependencies(testCase.module)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := testCase.expected
|
||||
for i := range expected {
|
||||
expected[i].path = moduleRoot + "/" + expected[i].path
|
||||
}
|
||||
|
||||
require.ElementsMatchf(t, modules, expected, "returned modules should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBuiltinModule(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
module string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Built In Should Return True",
|
||||
module: "foo_bi",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Not Built In Should Return False",
|
||||
module: "not_built_in",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
defer resetGlobals()
|
||||
_, _ = createFiles(t)
|
||||
|
||||
isBuiltIn, err := isBuiltinModule(testCase.module)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testCase.expected, isBuiltIn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModuleStatus(t *testing.T) {
|
||||
random, err := getRandomLoadedModule(t)
|
||||
if err != nil {
|
||||
t.Fatal("should be able to get random module")
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
module string
|
||||
shouldBeLoaded bool
|
||||
}{
|
||||
{
|
||||
name: "Should Return Module Loading Or Greater Status",
|
||||
module: random,
|
||||
shouldBeLoaded: true,
|
||||
},
|
||||
{
|
||||
name: "Should Return Module Unloaded Or Lower Status",
|
||||
module: "not_loaded_module",
|
||||
shouldBeLoaded: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
defer resetGlobals()
|
||||
_, _ = createFiles(t)
|
||||
|
||||
state, err := moduleStatus(testCase.module)
|
||||
require.NoError(t, err)
|
||||
if testCase.shouldBeLoaded {
|
||||
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
|
||||
} else {
|
||||
require.Less(t, state, loading, "module should return state unloading or lower")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func resetGlobals() {
|
||||
moduleLibDir = defaultModuleDir
|
||||
moduleRoot = getModuleRoot()
|
||||
}
|
||||
|
||||
func createFiles(t *testing.T) (string, []module) {
|
||||
writeFile := func(path, text string) {
|
||||
if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
var u unix.Utsname
|
||||
if err := unix.Uname(&u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
moduleLibDir = t.TempDir()
|
||||
|
||||
moduleRoot = getModuleRoot()
|
||||
if err := os.Mkdir(moduleRoot, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
text := "kernel/a/foo.ko:\n"
|
||||
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
|
||||
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
|
||||
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
|
||||
|
||||
text = "kernel/a/foo_bi.ko\n"
|
||||
text += "kernel/a/bar-bi.ko.gz\n"
|
||||
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
|
||||
|
||||
modules := []module{
|
||||
{name: "foo", path: "kernel/a/foo.ko"},
|
||||
{name: "bar", path: "kernel/a/bar.ko"},
|
||||
{name: "baz", path: "kernel/a/baz.ko"},
|
||||
}
|
||||
return moduleLibDir, modules
|
||||
}
|
||||
|
||||
func getRandomLoadedModule(t *testing.T) (string, error) {
|
||||
f, err := os.Open("/proc/modules")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err := f.Close()
|
||||
if err != nil {
|
||||
t.Logf("failed closing /proc/modules file, %v", err)
|
||||
}
|
||||
}()
|
||||
lines, err := lineCounter(f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
counter := 1
|
||||
midLine := lines / 2
|
||||
modName := ""
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if counter == midLine {
|
||||
if fields[4] == "Unloading" {
|
||||
continue
|
||||
}
|
||||
modName = fields[0]
|
||||
break
|
||||
}
|
||||
counter++
|
||||
}
|
||||
if scanner.Err() != nil {
|
||||
return "", scanner.Err()
|
||||
}
|
||||
return modName, nil
|
||||
}
|
||||
func lineCounter(r io.Reader) (int, error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
count := 0
|
||||
lineSep := []byte{'\n'}
|
||||
|
||||
for {
|
||||
c, err := r.Read(buf)
|
||||
count += bytes.Count(buf[:c], lineSep)
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
return count, nil
|
||||
|
||||
case err != nil:
|
||||
return count, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,20 +31,23 @@ const (
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||
GetAccountByUser(userId string) (*Account, error)
|
||||
AddSetupKey(
|
||||
CreateSetupKey(
|
||||
accountId string,
|
||||
keyName string,
|
||||
keyType SetupKeyType,
|
||||
expiresIn time.Duration,
|
||||
autoGroups []string,
|
||||
) (*SetupKey, error)
|
||||
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
|
||||
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
|
||||
SaveUser(accountID string, key *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, keyID string) (*SetupKey, error)
|
||||
GetAccountById(accountId string) (*Account, error)
|
||||
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
|
||||
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeer(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID string) ([]*PeerInfo, error)
|
||||
MarkPeerConnected(peerKey string, connected bool) error
|
||||
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
|
||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||
@@ -55,7 +58,7 @@ type AccountManager interface {
|
||||
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
|
||||
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
|
||||
UpdatePeerSSHKey(peerKey string, sshKey string) error
|
||||
GetUsersFromAccount(accountId string) ([]*UserInfo, error)
|
||||
GetUsers(accountId string) ([]*UserInfo, error)
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
SaveGroup(accountId string, group *Group) error
|
||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||
@@ -75,6 +78,7 @@ type AccountManager interface {
|
||||
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRoute(accountID, routeID string) error
|
||||
ListRoutes(accountID string) ([]*route.Route, error)
|
||||
ListSetupKeys(accountID string) ([]*SetupKey, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -105,10 +109,11 @@ type Account struct {
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
}
|
||||
|
||||
func (a *Account) Copy() *Account {
|
||||
@@ -244,93 +249,6 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
|
||||
func (am *DefaultAccountManager) AddSetupKey(
|
||||
accountId string,
|
||||
keyName string,
|
||||
keyType SetupKeyType,
|
||||
expiresIn time.Duration,
|
||||
) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
if expiresIn != 0 {
|
||||
keyDuration = expiresIn
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
|
||||
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
setupKey := getAccountSetupKeyById(account, keyId)
|
||||
if setupKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
|
||||
}
|
||||
|
||||
keyCopy := setupKey.Copy()
|
||||
keyCopy.Revoked = true
|
||||
account.SetupKeys[keyCopy.Key] = keyCopy
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||
}
|
||||
|
||||
return keyCopy, nil
|
||||
}
|
||||
|
||||
// RenameSetupKey renames existing setup key of the specified account.
|
||||
func (am *DefaultAccountManager) RenameSetupKey(
|
||||
accountId string,
|
||||
keyId string,
|
||||
newName string,
|
||||
) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
setupKey := getAccountSetupKeyById(account, keyId)
|
||||
if setupKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
|
||||
}
|
||||
|
||||
keyCopy := setupKey.Copy()
|
||||
keyCopy.Name = newName
|
||||
account.SetupKeys[keyCopy.Key] = keyCopy
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||
}
|
||||
|
||||
return keyCopy, nil
|
||||
}
|
||||
|
||||
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
|
||||
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
|
||||
am.mux.Lock()
|
||||
@@ -385,19 +303,25 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
|
||||
return &UserInfo{
|
||||
ID: local.Id,
|
||||
Email: queried.Email,
|
||||
Name: queried.Name,
|
||||
Role: string(local.Role),
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
||||
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) {
|
||||
userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, datum := range userData {
|
||||
if datum.ID == user.Id {
|
||||
return datum, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
|
||||
data, err := am.cacheManager.Get(am.ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -437,46 +361,6 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco
|
||||
return userData, err
|
||||
}
|
||||
|
||||
// GetUsersFromAccount performs a batched request for users from IDP by account id
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
|
||||
account, err := am.GetAccountById(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
queriedUsers, err = am.lookupCache(account.Users, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
userInfo := make([]*UserInfo, 0)
|
||||
|
||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||
if len(queriedUsers) == 0 {
|
||||
for _, user := range account.Users {
|
||||
userInfo = append(userInfo, &UserInfo{
|
||||
ID: user.Id,
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: string(user.Role),
|
||||
})
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
for _, queriedUser := range queriedUsers {
|
||||
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
||||
userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser))
|
||||
log.Debugf("Merged userinfo to send back; %v", userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||
account *Account,
|
||||
@@ -504,7 +388,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
//
|
||||
//
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
@@ -688,7 +571,7 @@ func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
|
||||
setupKeys := make(map[string]*SetupKey)
|
||||
defaultKey := GenerateDefaultSetupKey()
|
||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration)
|
||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{})
|
||||
setupKeys[defaultKey.Key] = defaultKey
|
||||
setupKeys[oneOffKey.Key] = oneOffKey
|
||||
network := NewNetwork()
|
||||
@@ -713,15 +596,6 @@ func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
return acc
|
||||
}
|
||||
|
||||
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
|
||||
for _, k := range acc.SetupKeys {
|
||||
if keyId == k.Id {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
|
||||
for _, k := range acc.SetupKeys {
|
||||
if key == k.Key {
|
||||
|
||||
@@ -847,7 +847,7 @@ func TestGetUsersFromAccount(t *testing.T) {
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
userInfos, err := manager.GetUsersFromAccount(accountId)
|
||||
userInfos, err := manager.GetUsers(accountId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gRPCPeer "google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
@@ -79,7 +80,10 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.Debugf("Sync request from peer %s", req.WgPubKey)
|
||||
p, ok := gRPCPeer.FromContext(srv.Context())
|
||||
if ok {
|
||||
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||
}
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
@@ -255,7 +259,10 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.Debugf("Login request from peer %s", req.WgPubKey)
|
||||
p, ok := gRPCPeer.FromContext(ctx)
|
||||
if ok {
|
||||
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||
}
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
|
||||
@@ -31,13 +31,45 @@ components:
|
||||
description: User's name from idp provider
|
||||
type: string
|
||||
role:
|
||||
description: User's Netbird account role
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
UserMinimum:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: User ID
|
||||
type: string
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
name:
|
||||
description: User's name from idp provider
|
||||
type: string
|
||||
UserRequest:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -90,6 +122,11 @@ components:
|
||||
ssh_enabled:
|
||||
description: Indicates whether SSH server is enabled on this peer
|
||||
type: boolean
|
||||
user:
|
||||
$ref: '#/components/schemas/UserMinimum'
|
||||
host_name:
|
||||
description: Peer's hostname
|
||||
type: string
|
||||
required:
|
||||
- ip
|
||||
- connected
|
||||
@@ -134,6 +171,15 @@ components:
|
||||
state:
|
||||
description: Setup key status, "valid", "overused","expired" or "revoked"
|
||||
type: string
|
||||
auto_groups:
|
||||
description: Setup key groups to auto-assign to peers registered with this key
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
updated_at:
|
||||
description: Setup key last update date
|
||||
type: string
|
||||
format: date-time
|
||||
required:
|
||||
- id
|
||||
- key
|
||||
@@ -145,6 +191,8 @@ components:
|
||||
- used_times
|
||||
- last_used
|
||||
- state
|
||||
- auto_groups
|
||||
- updated_at
|
||||
SetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -160,11 +208,17 @@ components:
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
type: boolean
|
||||
auto_groups:
|
||||
description: Setup key groups to auto-assign to peers registered with this key
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- revoked
|
||||
- auto_groups
|
||||
GroupMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -392,6 +446,40 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{id}:
|
||||
put:
|
||||
summary: Update information about a User
|
||||
tags: [ Users]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
requestBody:
|
||||
description: User update
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A User object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers:
|
||||
get:
|
||||
summary: Returns a list of all peers
|
||||
|
||||
@@ -137,6 +137,9 @@ type Peer struct {
|
||||
// Groups that the peer belongs to
|
||||
Groups []GroupMinimum `json:"groups"`
|
||||
|
||||
// Peer's hostname
|
||||
HostName *string `json:"host_name,omitempty"`
|
||||
|
||||
// Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
@@ -153,7 +156,8 @@ type Peer struct {
|
||||
Os string `json:"os"`
|
||||
|
||||
// Indicates whether SSH server is enabled on this peer
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
User *UserMinimum `json:"user,omitempty"`
|
||||
|
||||
// Peer's daemon or cli version
|
||||
Version string `json:"version"`
|
||||
@@ -299,6 +303,9 @@ type RulePatchOperationPath string
|
||||
|
||||
// SetupKey defines model for SetupKey.
|
||||
type SetupKey struct {
|
||||
// Setup key groups to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
@@ -323,6 +330,9 @@ type SetupKey struct {
|
||||
// Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
@@ -332,6 +342,9 @@ type SetupKey struct {
|
||||
|
||||
// SetupKeyRequest defines model for SetupKeyRequest.
|
||||
type SetupKeyRequest struct {
|
||||
// Setup key groups to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Expiration time in seconds
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
@@ -347,6 +360,9 @@ type SetupKeyRequest struct {
|
||||
|
||||
// User defines model for User.
|
||||
type User struct {
|
||||
// Groups to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
@@ -356,7 +372,28 @@ type User struct {
|
||||
// User's name from idp provider
|
||||
Name string `json:"name"`
|
||||
|
||||
// User's Netbird account role
|
||||
// User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserMinimum defines model for UserMinimum.
|
||||
type UserMinimum struct {
|
||||
// User's email address
|
||||
Email *string `json:"email,omitempty"`
|
||||
|
||||
// User ID
|
||||
Id *string `json:"id,omitempty"`
|
||||
|
||||
// User's name from idp provider
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// UserRequest defines model for UserRequest.
|
||||
type UserRequest struct {
|
||||
// Groups to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
@@ -433,6 +470,9 @@ type PostApiSetupKeysJSONBody = SetupKeyRequest
|
||||
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId.
|
||||
type PutApiSetupKeysIdJSONBody = SetupKeyRequest
|
||||
|
||||
// PutApiUsersIdJSONBody defines parameters for PutApiUsersId.
|
||||
type PutApiUsersIdJSONBody = UserRequest
|
||||
|
||||
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
||||
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
|
||||
|
||||
@@ -468,3 +508,6 @@ type PostApiSetupKeysJSONRequestBody = PostApiSetupKeysJSONBody
|
||||
|
||||
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
|
||||
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody
|
||||
|
||||
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
|
||||
type PutApiUsersIdJSONRequestBody = PutApiUsersIdJSONBody
|
||||
|
||||
@@ -39,12 +39,12 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
|
||||
|
||||
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//Peers is a handler that returns peers of the account
|
||||
// Peers is a handler that returns peers of the account
|
||||
type Peers struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
@@ -42,7 +42,7 @@ func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.Re
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toPeerResponse(peer, account))
|
||||
writeJSONObject(w, toPeerResponse(&server.PeerInfo{Peer: peer}, account))
|
||||
}
|
||||
|
||||
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||
@@ -83,7 +83,7 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
h.updatePeer(account, peer, w, r)
|
||||
return
|
||||
case http.MethodGet:
|
||||
writeJSONObject(w, toPeerResponse(peer, account))
|
||||
writeJSONObject(w, toPeerResponse(&server.PeerInfo{Peer: peer}, account))
|
||||
return
|
||||
|
||||
default:
|
||||
@@ -93,27 +93,34 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*api.Peer{}
|
||||
for _, peer := range account.Peers {
|
||||
respBody = append(respBody, toPeerResponse(peer, account))
|
||||
}
|
||||
writeJSONObject(w, respBody)
|
||||
return
|
||||
default:
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := h.accountManager.GetPeers(account.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*api.Peer{}
|
||||
for _, peer := range peers {
|
||||
respBody = append(respBody, toPeerResponse(peer, account))
|
||||
}
|
||||
writeJSONObject(w, respBody)
|
||||
return
|
||||
}
|
||||
|
||||
func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
||||
func toPeerResponse(peer *server.PeerInfo, account *server.Account) *api.Peer {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range account.Groups {
|
||||
@@ -123,7 +130,7 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
||||
}
|
||||
groupsChecked[group.ID] = struct{}{}
|
||||
for _, pk := range group.Peers {
|
||||
if pk == peer.Key {
|
||||
if pk == peer.Peer.Key {
|
||||
info := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
@@ -134,15 +141,26 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
||||
}
|
||||
}
|
||||
}
|
||||
return &api.Peer{
|
||||
Id: peer.IP.String(),
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core),
|
||||
Version: peer.Meta.WtVersion,
|
||||
resp := &api.Peer{
|
||||
Id: peer.Peer.IP.String(),
|
||||
Name: peer.Peer.Name,
|
||||
Ip: peer.Peer.IP.String(),
|
||||
Connected: peer.Peer.Status.Connected,
|
||||
LastSeen: peer.Peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Peer.Meta.OS, peer.Peer.Meta.Core),
|
||||
Version: peer.Peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
SshEnabled: peer.Peer.SSHEnabled,
|
||||
HostName: &peer.Peer.Meta.Hostname,
|
||||
}
|
||||
|
||||
if peer.UserInfo != nil {
|
||||
resp.User = &api.UserMinimum{
|
||||
Email: &peer.UserInfo.Email,
|
||||
Id: &peer.UserInfo.ID,
|
||||
Name: &peer.UserInfo.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -348,6 +348,11 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -78,7 +78,10 @@ func initRoutesTestData() *Routes {
|
||||
SaveRouteFunc: func(_ string, _ *route.Route) error {
|
||||
return nil
|
||||
},
|
||||
DeleteRouteFunc: func(_ string, _ string) error {
|
||||
DeleteRouteFunc: func(_ string, peerIP string) error {
|
||||
if peerIP != existingRouteID {
|
||||
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||
@@ -155,7 +158,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
{
|
||||
name: "Get Not Existing Route",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/rules/" + notFoundRouteID,
|
||||
requestPath: "/api/routes/" + notFoundRouteID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
@@ -168,7 +171,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
{
|
||||
name: "Delete Not Existing Route",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/rules/" + notFoundRouteID,
|
||||
requestPath: "/api/routes/" + notFoundRouteID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@@ -28,54 +29,17 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var key *server.SetupKey
|
||||
if req.Revoked {
|
||||
//handle only if being revoked, don't allow to enable key again for now
|
||||
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
|
||||
if err != nil {
|
||||
http.Error(w, "failed revoking key", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(req.Name) != 0 {
|
||||
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
|
||||
if err != nil {
|
||||
http.Error(w, "failed renaming key", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if key != nil {
|
||||
writeSuccess(w, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.accountManager.GetAccountById(accountId)
|
||||
if err != nil {
|
||||
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyId {
|
||||
writeSuccess(w, key)
|
||||
return
|
||||
}
|
||||
}
|
||||
http.Error(w, "setup key not found", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@@ -95,7 +59,13 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
||||
|
||||
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||
|
||||
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, server.SetupKeyType(req.Type), expiresIn)
|
||||
if req.AutoGroups == nil {
|
||||
req.AutoGroups = []string{}
|
||||
}
|
||||
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
|
||||
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
@@ -109,7 +79,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
||||
writeSuccess(w, setupKey)
|
||||
}
|
||||
|
||||
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@@ -118,25 +89,84 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyId := vars["id"]
|
||||
if len(keyId) == 0 {
|
||||
keyID := vars["id"]
|
||||
if len(keyID) == 0 {
|
||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
h.updateKey(account.Id, keyId, w, r)
|
||||
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getKey(account.Id, keyId, w, r)
|
||||
return
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
|
||||
writeSuccess(w, key)
|
||||
}
|
||||
|
||||
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
||||
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["id"]
|
||||
if len(keyID) == 0 {
|
||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newKey := &server.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
|
||||
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); ok {
|
||||
switch e.Code() {
|
||||
case codes.NotFound:
|
||||
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
|
||||
default:
|
||||
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
writeSuccess(w, newKey)
|
||||
}
|
||||
|
||||
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
@@ -145,28 +175,18 @@ func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
h.createKey(account.Id, w, r)
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
case http.MethodGet:
|
||||
w.WriteHeader(200)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
respBody := []*api.SetupKey{}
|
||||
for _, key := range account.SetupKeys {
|
||||
respBody = append(respBody, toResponseBody(key))
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(respBody)
|
||||
if err != nil {
|
||||
log.Errorf("failed encoding account peers %s: %v", account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
apiSetupKeys := make([]*api.SetupKey, 0)
|
||||
for _, key := range setupKeys {
|
||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||
}
|
||||
|
||||
writeJSONObject(w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
@@ -190,16 +210,19 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
|
||||
} else {
|
||||
state = "valid"
|
||||
}
|
||||
|
||||
return &api.SetupKey{
|
||||
Id: key.Id,
|
||||
Key: key.Key,
|
||||
Name: key.Name,
|
||||
Expires: key.ExpiresAt,
|
||||
Type: string(key.Type),
|
||||
Valid: key.IsValid(),
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
State: state,
|
||||
Id: key.Id,
|
||||
Key: key.Key,
|
||||
Name: key.Name,
|
||||
Expires: key.ExpiresAt,
|
||||
Type: string(key.Type),
|
||||
Valid: key.IsValid(),
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
State: state,
|
||||
AutoGroups: key.AutoGroups,
|
||||
UpdatedAt: key.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
222
management/server/http/setupkeys_test.go
Normal file
222
management/server/http/setupkeys_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
const (
|
||||
existingSetupKeyID = "existingSetupKeyID"
|
||||
newSetupKeyName = "New Setup Key"
|
||||
updatedSetupKeyName = "KKKey"
|
||||
notFoundSetupKeyID = "notFoundSetupKeyID"
|
||||
)
|
||||
|
||||
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
|
||||
return &SetupKeys{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
SetupKeys: map[string]*server.SetupKey{
|
||||
defaultKey.Key: defaultKey,
|
||||
},
|
||||
Groups: map[string]*server.Group{
|
||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
}, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
return newKey, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed creating setup key")
|
||||
},
|
||||
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
|
||||
switch keyID {
|
||||
case defaultKey.Id:
|
||||
return defaultKey, nil
|
||||
case newKey.Id:
|
||||
return newKey, nil
|
||||
default:
|
||||
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
|
||||
}
|
||||
},
|
||||
|
||||
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
|
||||
if key.Id == updatedSetupKey.Id {
|
||||
return updatedSetupKey, nil
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
|
||||
},
|
||||
|
||||
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupKeysHandlers(t *testing.T) {
|
||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||
defaultSetupKey.Id = existingSetupKeyID
|
||||
|
||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
|
||||
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
||||
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
||||
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
||||
updatedDefaultSetupKey.Revoked = true
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
expectedSetupKey *api.SetupKey
|
||||
expectedSetupKeys []*api.SetupKey
|
||||
}{
|
||||
{
|
||||
name: "Get Setup Keys",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/setup-keys",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)},
|
||||
},
|
||||
{
|
||||
name: "Get Existing Setup Key",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/setup-keys/" + existingSetupKeyID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedSetupKey: toResponseBody(defaultSetupKey),
|
||||
},
|
||||
{
|
||||
name: "Get Not Existing Setup Key",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/setup-keys/" + notFoundSetupKeyID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "Create Setup Key",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/setup-keys",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\"}", newSetupKey.Name, newSetupKey.Type))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedSetupKey: toResponseBody(newSetupKey),
|
||||
},
|
||||
{
|
||||
name: "Update Setup Key",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"auto_groups\":[\"%s\"], \"revoked\":%v}",
|
||||
updatedDefaultSetupKey.Type,
|
||||
updatedDefaultSetupKey.AutoGroups[0],
|
||||
updatedDefaultSetupKey.Revoked,
|
||||
))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
||||
},
|
||||
}
|
||||
|
||||
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", handler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectedSetupKey != nil {
|
||||
got := &api.SetupKey{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assertKeys(t, got, tc.expectedSetupKey)
|
||||
return
|
||||
}
|
||||
|
||||
if len(tc.expectedSetupKeys) > 0 {
|
||||
var got []*api.SetupKey
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assertKeys(t, got[0], tc.expectedSetupKeys[0])
|
||||
return
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
|
||||
// this comparison is done manually because when converting to JSON dates formatted differently
|
||||
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
|
||||
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
|
||||
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
|
||||
assert.Equal(t, got.Name, expected.Name)
|
||||
assert.Equal(t, got.Id, expected.Id)
|
||||
assert.Equal(t, got.Key, expected.Key)
|
||||
assert.Equal(t, got.Type, expected.Type)
|
||||
assert.Equal(t, got.UsedTimes, expected.UsedTimes)
|
||||
assert.Equal(t, got.Revoked, expected.Revoked)
|
||||
assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups)
|
||||
}
|
||||
@@ -1,7 +1,12 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -24,6 +29,59 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateUser is a PUT requests to update User data
|
||||
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["id"]
|
||||
if len(userID) == 0 {
|
||||
http.Error(w, "invalid user ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiUsersIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
http.Error(w, "invalid user role", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(account.Id, &server.User{
|
||||
Id: userID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
})
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); ok {
|
||||
switch e.Code() {
|
||||
case codes.NotFound:
|
||||
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
|
||||
default:
|
||||
http.Error(w, "failed to update user", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toUserResponse(newUser))
|
||||
|
||||
}
|
||||
|
||||
// GetUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -34,9 +92,11 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(account.Id)
|
||||
data, err := h.accountManager.GetUsers(account.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
@@ -52,10 +112,17 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo) *api.User {
|
||||
|
||||
autoGroups := user.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,8 @@ import (
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration) (*server.SetupKey, error)
|
||||
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
|
||||
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||
@@ -51,6 +50,9 @@ type MockAccountManager struct {
|
||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRouteFunc func(accountID, routeID string) error
|
||||
ListRoutesFunc func(accountID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@@ -58,7 +60,7 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
|
||||
if am.GetUsersFromAccountFunc != nil {
|
||||
return am.GetUsersFromAccountFunc(accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsers is not implemented")
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
@@ -82,40 +84,18 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
|
||||
}
|
||||
|
||||
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) AddSetupKey(
|
||||
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateSetupKey(
|
||||
accountId string,
|
||||
keyName string,
|
||||
keyType server.SetupKeyType,
|
||||
expiresIn time.Duration,
|
||||
autoGroups []string,
|
||||
) (*server.SetupKey, error) {
|
||||
if am.AddSetupKeyFunc != nil {
|
||||
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
|
||||
if am.CreateSetupKeyFunc != nil {
|
||||
return am.CreateSetupKeyFunc(accountId, keyName, keyType, expiresIn, autoGroups)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) RevokeSetupKey(
|
||||
accountId string,
|
||||
keyId string,
|
||||
) (*server.SetupKey, error) {
|
||||
if am.RevokeSetupKeyFunc != nil {
|
||||
return am.RevokeSetupKeyFunc(accountId, keyId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) RenameSetupKey(
|
||||
accountId string,
|
||||
keyId string,
|
||||
newName string,
|
||||
) (*server.SetupKey, error) {
|
||||
if am.RenameSetupKeyFunc != nil {
|
||||
return am.RenameSetupKeyFunc(accountId, keyId, newName)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
|
||||
@@ -415,3 +395,38 @@ func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, erro
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
|
||||
}
|
||||
|
||||
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
|
||||
if am.SaveSetupKeyFunc != nil {
|
||||
return am.SaveSetupKeyFunc(accountID, key)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetSetupKey mocks GetSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
|
||||
if am.GetSetupKeyFunc != nil {
|
||||
return am.GetSetupKeyFunc(accountID, keyID)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
|
||||
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
|
||||
if am.ListSetupKeysFunc != nil {
|
||||
return am.ListSetupKeysFunc(accountID)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
|
||||
}
|
||||
|
||||
// SaveUser mocks SaveUser of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveUser(accountID string, user *server.User) (*server.UserInfo, error) {
|
||||
if am.SaveUserFunc != nil {
|
||||
return am.SaveUserFunc(accountID, user)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
|
||||
}
|
||||
|
||||
@@ -31,6 +31,12 @@ type PeerStatus struct {
|
||||
Connected bool
|
||||
}
|
||||
|
||||
// PeerInfo is a composition of Peer and additional UserInfo
|
||||
type PeerInfo struct {
|
||||
Peer *Peer
|
||||
UserInfo *UserInfo
|
||||
}
|
||||
|
||||
// Peer represents a machine connected to the network.
|
||||
// The Peer is a Wireguard peer identified by a public key
|
||||
type Peer struct {
|
||||
@@ -68,6 +74,44 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeers returns a list of Peers belonging to the specified account
|
||||
func (am *DefaultAccountManager) GetPeers(accountID string) ([]*PeerInfo, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
users, err := am.getUsersInfos(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userMap = make(map[string]*UserInfo)
|
||||
for _, user := range users {
|
||||
userMap[user.ID] = user
|
||||
}
|
||||
|
||||
var peerInfos []*PeerInfo
|
||||
for _, peer := range account.Peers {
|
||||
if peer.UserID == "" {
|
||||
peerInfos = append(peerInfos, &PeerInfo{
|
||||
Peer: peer,
|
||||
UserInfo: nil,
|
||||
})
|
||||
} else {
|
||||
peerInfos = append(peerInfos, &PeerInfo{
|
||||
Peer: peer.Copy(),
|
||||
UserInfo: userMap[peer.UserID],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return peerInfos, nil
|
||||
}
|
||||
|
||||
// GetPeer returns a peer from a Store
|
||||
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
@@ -294,6 +338,8 @@ func (am *DefaultAccountManager) AddPeer(
|
||||
var account *Account
|
||||
var err error
|
||||
var sk *SetupKey
|
||||
// auto-assign groups that are coming with a SetupKey or a User
|
||||
var groupsToAdd []string
|
||||
if len(upperKey) != 0 {
|
||||
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
||||
if err != nil {
|
||||
@@ -321,11 +367,20 @@ func (am *DefaultAccountManager) AddPeer(
|
||||
)
|
||||
}
|
||||
|
||||
groupsToAdd = sk.AutoGroups
|
||||
|
||||
} else if len(userID) != 0 {
|
||||
account, err = am.Store.GetUserAccount(userID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
||||
}
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
||||
}
|
||||
|
||||
groupsToAdd = user.AutoGroups
|
||||
|
||||
} else {
|
||||
// Empty setup key and jwt fail
|
||||
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
|
||||
@@ -361,6 +416,14 @@ func (am *DefaultAccountManager) AddPeer(
|
||||
}
|
||||
group.Peers = append(group.Peers, newPeer.Key)
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, s := range groupsToAdd {
|
||||
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
||||
g.Peers = append(g.Peers, newPeer.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
account.Peers[newPeer.Key] = newPeer
|
||||
if len(upperKey) != 0 {
|
||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,8 +21,41 @@ const (
|
||||
DefaultSetupKeyDuration = 24 * 30 * time.Hour
|
||||
// DefaultSetupKeyName is a default name of the default setup key
|
||||
DefaultSetupKeyName = "Default key"
|
||||
|
||||
// UpdateSetupKeyName indicates a setup key name update operation
|
||||
UpdateSetupKeyName SetupKeyUpdateOperationType = iota
|
||||
// UpdateSetupKeyRevoked indicates a setup key revoked filed update operation
|
||||
UpdateSetupKeyRevoked
|
||||
// UpdateSetupKeyAutoGroups indicates a setup key auto-assign groups update operation
|
||||
UpdateSetupKeyAutoGroups
|
||||
// UpdateSetupKeyExpiresAt indicates a setup key expiration time update operation
|
||||
UpdateSetupKeyExpiresAt
|
||||
)
|
||||
|
||||
// SetupKeyUpdateOperationType operation type
|
||||
type SetupKeyUpdateOperationType int
|
||||
|
||||
func (t SetupKeyUpdateOperationType) String() string {
|
||||
switch t {
|
||||
case UpdateSetupKeyName:
|
||||
return "UpdateSetupKeyName"
|
||||
case UpdateSetupKeyRevoked:
|
||||
return "UpdateSetupKeyRevoked"
|
||||
case UpdateSetupKeyAutoGroups:
|
||||
return "UpdateSetupKeyAutoGroups"
|
||||
case UpdateSetupKeyExpiresAt:
|
||||
return "UpdateSetupKeyExpiresAt"
|
||||
default:
|
||||
return "InvalidOperation"
|
||||
}
|
||||
}
|
||||
|
||||
// SetupKeyUpdateOperation operation object with type and values to be applied
|
||||
type SetupKeyUpdateOperation struct {
|
||||
Type SetupKeyUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
|
||||
// SetupKeyType is the type of setup key
|
||||
type SetupKeyType string
|
||||
|
||||
@@ -31,30 +67,40 @@ type SetupKey struct {
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
UpdatedAt time.Time
|
||||
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
|
||||
Revoked bool
|
||||
// UsedTimes indicates how many times the key was used
|
||||
UsedTimes int
|
||||
// LastUsed last time the key was used for peer registration
|
||||
LastUsed time.Time
|
||||
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
|
||||
AutoGroups []string
|
||||
}
|
||||
|
||||
//Copy copies SetupKey to a new object
|
||||
// Copy copies SetupKey to a new object
|
||||
func (key *SetupKey) Copy() *SetupKey {
|
||||
autoGroups := make([]string, 0)
|
||||
autoGroups = append(autoGroups, key.AutoGroups...)
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
return &SetupKey{
|
||||
Id: key.Id,
|
||||
Key: key.Key,
|
||||
Name: key.Name,
|
||||
Type: key.Type,
|
||||
CreatedAt: key.CreatedAt,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
Id: key.Id,
|
||||
Key: key.Key,
|
||||
Name: key.Name,
|
||||
Type: key.Type,
|
||||
CreatedAt: key.CreatedAt,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
UpdatedAt: key.UpdatedAt,
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
//IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||
func (key *SetupKey) IncrementUsage() *SetupKey {
|
||||
c := key.Copy()
|
||||
c.UsedTimes = c.UsedTimes + 1
|
||||
@@ -83,24 +129,25 @@ func (key *SetupKey) IsOverUsed() bool {
|
||||
}
|
||||
|
||||
// GenerateSetupKey generates a new setup key
|
||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey {
|
||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string) *SetupKey {
|
||||
key := strings.ToUpper(uuid.New().String())
|
||||
createdAt := time.Now()
|
||||
return &SetupKey{
|
||||
Id: strconv.Itoa(int(Hash(key))),
|
||||
Key: key,
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: createdAt.Add(validFor),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
Id: strconv.Itoa(int(Hash(key))),
|
||||
Key: key,
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(validFor),
|
||||
UpdatedAt: time.Now(),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateDefaultSetupKey generates a default setup key
|
||||
func GenerateDefaultSetupKey() *SetupKey {
|
||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration)
|
||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{})
|
||||
}
|
||||
|
||||
func Hash(s string) uint32 {
|
||||
@@ -111,3 +158,127 @@ func Hash(s string) uint32 {
|
||||
}
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
if expiresIn != 0 {
|
||||
keyDuration = expiresIn
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
for _, group := range autoGroups {
|
||||
if _, ok := account.Groups[group]; !ok {
|
||||
return nil, fmt.Errorf("group %s doesn't exist", group)
|
||||
}
|
||||
}
|
||||
|
||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyToSave.Id {
|
||||
oldKey = key.Copy()
|
||||
break
|
||||
}
|
||||
}
|
||||
if oldKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey := oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now()
|
||||
|
||||
account.SetupKeys[newKey.Key] = newKey
|
||||
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newKey, am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
keys := make([]*SetupKey, 0, len(account.SetupKeys))
|
||||
for _, key := range account.SetupKeys {
|
||||
keys = append(keys, key.Copy())
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
var foundKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyID {
|
||||
foundKey = key.Copy()
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||
if foundKey.UpdatedAt.IsZero() {
|
||||
foundKey.UpdatedAt = foundKey.CreatedAt
|
||||
}
|
||||
|
||||
return foundKey, nil
|
||||
}
|
||||
|
||||
@@ -2,23 +2,159 @@ package server
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(account.Id, &Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiresIn := time.Hour
|
||||
keyName := "my-test-key"
|
||||
|
||||
key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now(), autoGroups)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(account.Id, &Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(account.Id, &Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
|
||||
expectedKeyName string
|
||||
expectedUsedTimes int
|
||||
expectedType string
|
||||
expectedGroups []string
|
||||
expectedCreatedAt time.Time
|
||||
expectedUpdatedAt time.Time
|
||||
expectedExpiresAt time.Time
|
||||
expectedFailure bool //indicates whether key creation should fail
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expiresIn := time.Hour
|
||||
testCase1 := testCase{
|
||||
name: "Should Create Setup Key successfully",
|
||||
expectedKeyName: "my-test-key",
|
||||
expectedUsedTimes: 0,
|
||||
expectedType: "reusable",
|
||||
expectedGroups: []string{"group_1", "group_2"},
|
||||
expectedCreatedAt: now,
|
||||
expectedUpdatedAt: now,
|
||||
expectedExpiresAt: now.Add(expiresIn),
|
||||
expectedFailure: false,
|
||||
}
|
||||
testCase2 := testCase{
|
||||
name: "Create Setup Key should fail because of unexistent group",
|
||||
expectedKeyName: "my-test-key",
|
||||
expectedGroups: []string{"FAKE"},
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
if err == nil {
|
||||
t.Fatal("expected to fail")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
|
||||
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGenerateDefaultSetupKey(t *testing.T) {
|
||||
expectedName := "Default key"
|
||||
expectedRevoke := false
|
||||
expectedType := "reusable"
|
||||
expectedUsedTimes := 0
|
||||
expectedCreatedAt := time.Now()
|
||||
expectedUpdatedAt := time.Now()
|
||||
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateDefaultSetupKey()
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||
|
||||
}
|
||||
|
||||
@@ -29,41 +165,44 @@ func TestGenerateSetupKey(t *testing.T) {
|
||||
expectedUsedTimes := 0
|
||||
expectedCreatedAt := time.Now()
|
||||
expectedExpiresAt := time.Now().Add(time.Hour)
|
||||
expectedUpdatedAt := time.Now()
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour)
|
||||
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{})
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||
|
||||
}
|
||||
|
||||
func TestSetupKey_IsValid(t *testing.T) {
|
||||
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour)
|
||||
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{})
|
||||
if !validKey.IsValid() {
|
||||
t.Errorf("expected key to be valid, got invalid %v", validKey)
|
||||
}
|
||||
|
||||
// expired
|
||||
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour)
|
||||
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{})
|
||||
if expiredKey.IsValid() {
|
||||
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
|
||||
}
|
||||
|
||||
// revoked
|
||||
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
|
||||
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
|
||||
revokedKey.Revoked = true
|
||||
if revokedKey.IsValid() {
|
||||
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
|
||||
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
|
||||
overUsedKey.UsedTimes = 1
|
||||
if overUsedKey.IsValid() {
|
||||
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour)
|
||||
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{})
|
||||
reusableKey.UsedTimes = 99
|
||||
if !reusableKey.IsValid() {
|
||||
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
|
||||
@@ -71,7 +210,8 @@ func TestSetupKey_IsValid(t *testing.T) {
|
||||
}
|
||||
|
||||
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
||||
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) {
|
||||
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
|
||||
expectedUpdatedAt time.Time, expectedAutoGroups []string) {
|
||||
if key.Name != expectedName {
|
||||
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
|
||||
}
|
||||
@@ -92,6 +232,10 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
||||
t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt)
|
||||
}
|
||||
|
||||
if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 {
|
||||
t.Errorf("expected setup key to have UpdatedAt ~ %v, got %v", expectedUpdatedAt, key.UpdatedAt)
|
||||
}
|
||||
|
||||
if key.CreatedAt.Sub(expectedCreatedAt).Round(time.Hour) != 0 {
|
||||
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
|
||||
}
|
||||
@@ -104,13 +248,19 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
||||
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
|
||||
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
|
||||
}
|
||||
|
||||
if len(key.AutoGroups) != len(expectedAutoGroups) {
|
||||
t.Errorf("expected key AutoGroups size=%d, got %d", len(expectedAutoGroups), len(key.AutoGroups))
|
||||
}
|
||||
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
|
||||
}
|
||||
|
||||
func TestSetupKey_Copy(t *testing.T) {
|
||||
|
||||
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour)
|
||||
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{})
|
||||
keyCopy := key.Copy()
|
||||
|
||||
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id)
|
||||
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
|
||||
key.UpdatedAt, key.AutoGroups)
|
||||
|
||||
}
|
||||
|
||||
@@ -2,19 +2,32 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
)
|
||||
|
||||
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||
func StrRoleToUserRole(strRole string) UserRole {
|
||||
switch strings.ToLower(strRole) {
|
||||
case "admin":
|
||||
return UserRoleAdmin
|
||||
case "user":
|
||||
return UserRoleUser
|
||||
default:
|
||||
return UserRoleUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// UserRole is the role of the User
|
||||
type UserRole string
|
||||
|
||||
@@ -22,20 +35,56 @@ type UserRole string
|
||||
type User struct {
|
||||
Id string
|
||||
Role UserRole
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
}
|
||||
|
||||
// toUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
autoGroups := []string{}
|
||||
autoGroups = append(autoGroups, u.AutoGroups...)
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
Id: id,
|
||||
Role: role,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +98,55 @@ func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin)
|
||||
}
|
||||
|
||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups field is allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
if update == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return nil,
|
||||
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "update not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.AutoGroups = update.AutoGroups
|
||||
newUser.Role = update.Role
|
||||
|
||||
account.Users[newUser.Id] = newUser
|
||||
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
userData, err := am.lookupUserInCache(newUser, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newUser.toUserInfo(userData)
|
||||
}
|
||||
return newUser.toUserInfo(nil)
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
|
||||
am.mux.Lock()
|
||||
@@ -108,3 +206,50 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
|
||||
|
||||
return user.Role == UserRoleAdmin, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getUsersInfos(account *Account) ([]*UserInfo, error) {
|
||||
var err error
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
queriedUsers, err = am.lookupCache(account.Users, account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
userInfos := make([]*UserInfo, 0)
|
||||
|
||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||
if len(queriedUsers) == 0 {
|
||||
for _, user := range account.Users {
|
||||
info, err := user.toUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
}
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
for _, queriedUser := range queriedUsers {
|
||||
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
||||
|
||||
info, err := localUser.toUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
}
|
||||
}
|
||||
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
// GetUsers performs a batched request for users from IDP by account ID
|
||||
func (am *DefaultAccountManager) GetUsers(accountID string) ([]*UserInfo, error) {
|
||||
account, err := am.GetAccountById(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.getUsersInfos(account)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"io"
|
||||
@@ -41,13 +42,15 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
}
|
||||
|
||||
// MarshalCredential marsharl a Credential instance and returns a Message object
|
||||
func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type) (*proto.Message, error) {
|
||||
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type) (*proto.Message, error) {
|
||||
return &proto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey.String(),
|
||||
Body: &proto.Body{
|
||||
Type: t,
|
||||
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
|
||||
Type: t,
|
||||
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
|
||||
WgListenPort: uint32(myPort),
|
||||
NetBirdVersion: system.NetbirdVersion(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
4
signal/proto/generate.sh
Executable file
4
signal/proto/generate.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I proto/ proto/signalexchange.proto --go_out=. --go-grpc_out=.
|
||||
@@ -214,6 +214,9 @@ type Body struct {
|
||||
|
||||
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
|
||||
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
|
||||
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Body) Reset() {
|
||||
@@ -262,6 +265,20 @@ func (x *Body) GetPayload() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Body) GetWgListenPort() uint32 {
|
||||
if x != nil {
|
||||
return x.WgListenPort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *Body) GetNetBirdVersion() string {
|
||||
if x != nil {
|
||||
return x.NetBirdVersion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_signalexchange_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_signalexchange_proto_rawDesc = []byte{
|
||||
@@ -281,28 +298,32 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
||||
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
||||
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0x7d, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a,
|
||||
0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69,
|
||||
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64,
|
||||
0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, 0x07,
|
||||
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70,
|
||||
0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x2c, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
|
||||
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
|
||||
0x54, 0x45, 0x10, 0x02, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45,
|
||||
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12,
|
||||
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
||||
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
|
||||
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74,
|
||||
0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
|
||||
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
|
||||
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
|
||||
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
|
||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01,
|
||||
0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc9, 0x01, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
||||
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
||||
0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
|
||||
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73,
|
||||
0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77,
|
||||
0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
|
||||
0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73,
|
||||
0x69, 0x6f, 0x6e, 0x22, 0x2c, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
|
||||
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
|
||||
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
|
||||
0x02, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68,
|
||||
0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
|
||||
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
|
||||
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72,
|
||||
0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68,
|
||||
0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65,
|
||||
0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a,
|
||||
0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -49,4 +49,7 @@ message Body {
|
||||
}
|
||||
Type type = 1;
|
||||
string payload = 2;
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
uint32 wgListenPort = 3;
|
||||
string netBirdVersion = 4;
|
||||
}
|
||||
Reference in New Issue
Block a user