mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
16 Commits
feature/in
...
feature/ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d197cd5f9 | ||
|
|
6bee984b46 | ||
|
|
2ee7d69f80 | ||
|
|
af69a48745 | ||
|
|
68ff97ba84 | ||
|
|
c5705803a5 | ||
|
|
7e1ae448e0 | ||
|
|
518a2561a2 | ||
|
|
c75ffd0f4b | ||
|
|
e4ad6174ca | ||
|
|
6de313070a | ||
|
|
cd7d1a80c9 | ||
|
|
be7d829858 | ||
|
|
ed1872560f | ||
|
|
de898899a4 | ||
|
|
b63ec71aed |
52
.github/workflows/golang-test-linux.yml
vendored
52
.github/workflows/golang-test-linux.yml
vendored
@@ -33,3 +33,55 @@ jobs:
|
|||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||||
|
|
||||||
|
test_client_on_docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: 1.18.x
|
||||||
|
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: Generate Iface Test bin
|
||||||
|
run: go test -c -o iface-testing.bin ./iface/...
|
||||||
|
|
||||||
|
- name: Generate RouteManager Test bin
|
||||||
|
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
|
- name: Generate Engine Test bin
|
||||||
|
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
||||||
|
|
||||||
|
- name: Generate Peer Test bin
|
||||||
|
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||||
|
|
||||||
|
- run: chmod +x *testing.bin
|
||||||
|
|
||||||
|
- name: Run Iface tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin
|
||||||
|
|
||||||
|
- name: Run RouteManager tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin
|
||||||
|
|
||||||
|
- name: Run Engine tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin
|
||||||
|
|
||||||
|
- name: Run Peer tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin
|
||||||
28
README.md
28
README.md
@@ -16,7 +16,7 @@
|
|||||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||||
<img src="https://img.shields.io/badge/slack-@wiretrustee-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
@@ -43,20 +43,20 @@ It requires zero configuration effort leaving behind the hassle of opening ports
|
|||||||
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
|
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
|
||||||
|
|
||||||
**Key features:**
|
**Key features:**
|
||||||
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
||||||
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
||||||
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
||||||
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
||||||
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
||||||
- \[x] Multiuser support - sharing network between multiple users.
|
- \[x] Multiuser support - sharing network between multiple users.
|
||||||
- \[x] SSO and MFA support.
|
- \[x] SSO and MFA support.
|
||||||
- \[x] Multicloud and hybrid-cloud support.
|
- \[x] Multicloud and hybrid-cloud support.
|
||||||
- \[x] Kernel WireGuard usage when possible.
|
- \[x] Kernel WireGuard usage when possible.
|
||||||
- \[x] Access Controls - groups & rules.
|
- \[x] Access Controls - groups & rules.
|
||||||
- \[x] Remote SSH access without managing SSH keys.
|
- \[x] Remote SSH access without managing SSH keys.
|
||||||
|
- \[x] Network Routes.
|
||||||
|
|
||||||
**Coming soon:**
|
**Coming soon:**
|
||||||
- \[ ] Network Routes.
|
|
||||||
- \[ ] Private DNS.
|
- \[ ] Private DNS.
|
||||||
- \[ ] Mobile clients.
|
- \[ ] Mobile clients.
|
||||||
- \[ ] Network Activity Monitoring.
|
- \[ ] Network Activity Monitoring.
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
|
ipv4Flag bool
|
||||||
ipsFilter []string
|
ipsFilter []string
|
||||||
statusFilter string
|
statusFilter string
|
||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
@@ -73,7 +75,7 @@ var statusCmd = &cobra.Command{
|
|||||||
pbFullStatus := resp.GetFullStatus()
|
pbFullStatus := resp.GetFullStatus()
|
||||||
fullStatus := fromProtoFullStatus(pbFullStatus)
|
fullStatus := fromProtoFullStatus(pbFullStatus)
|
||||||
|
|
||||||
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion()))
|
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -82,8 +84,9 @@ var statusCmd = &cobra.Command{
|
|||||||
func init() {
|
func init() {
|
||||||
ipsFilterMap = make(map[string]struct{})
|
ipsFilterMap = make(map[string]struct{})
|
||||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
|
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g. --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g. --filter-by-status connected")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
@@ -142,7 +145,19 @@ func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
|
|||||||
return fullStatus
|
return fullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string) string {
|
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
|
||||||
|
|
||||||
|
interfaceIP := fullStatus.LocalPeerState.IP
|
||||||
|
|
||||||
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipv4Flag {
|
||||||
|
return fmt.Sprintf("%s\n", ip)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
managementStatusURL = ""
|
managementStatusURL = ""
|
||||||
signalStatusURL = ""
|
signalStatusURL = ""
|
||||||
@@ -164,8 +179,6 @@ func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonSta
|
|||||||
signalConnString = "Connected"
|
signalConnString = "Connected"
|
||||||
}
|
}
|
||||||
|
|
||||||
interfaceIP := fullStatus.LocalPeerState.IP
|
|
||||||
|
|
||||||
if fullStatus.LocalPeerState.KernelInterface {
|
if fullStatus.LocalPeerState.KernelInterface {
|
||||||
interfaceTypeString = "Kernel"
|
interfaceTypeString = "Kernel"
|
||||||
} else if fullStatus.LocalPeerState.IP == "" {
|
} else if fullStatus.LocalPeerState.IP == "" {
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return DeviceAuthorizationFlow{
|
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||||
|
|
||||||
ProviderConfig: ProviderConfig{
|
ProviderConfig: ProviderConfig{
|
||||||
@@ -274,5 +274,32 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
|
|||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||||
},
|
},
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceAuthorizationFlow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProviderConfigValid(config ProviderConfig) error {
|
||||||
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.ClientSecret == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Client Secret")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
localPeerState := nbStatus.LocalPeerState{
|
localPeerState := nbStatus.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModExists(),
|
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||||
}
|
}
|
||||||
|
|
||||||
statusRecorder.UpdateLocalPeerState(localPeerState)
|
statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -36,7 +36,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
cStatus, ok := status.FromError(err)
|
||||||
|
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||||
|
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
|
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||||
c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err)
|
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -84,8 +84,10 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
defer n.mux.Unlock()
|
defer n.mux.Unlock()
|
||||||
log.Debug("flushing tables")
|
log.Debug("flushing tables")
|
||||||
n.conn.FlushTable(n.tableIPv6)
|
if n.tableIPv4 != nil && n.tableIPv6 != nil {
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
|
}
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
12
go.mod
12
go.mod
@@ -11,7 +11,7 @@ require (
|
|||||||
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 //keep this version otherwise wiretrustee up command breaks
|
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 //keep this version otherwise wiretrustee up command breaks
|
||||||
github.com/onsi/ginkgo v1.16.5
|
github.com/onsi/ginkgo v1.16.5
|
||||||
github.com/onsi/gomega v1.18.1
|
github.com/onsi/gomega v1.18.1
|
||||||
github.com/pion/ice/v2 v2.1.17
|
github.com/pion/ice/v2 v2.2.7
|
||||||
github.com/rs/cors v1.8.0
|
github.com/rs/cors v1.8.0
|
||||||
github.com/sirupsen/logrus v1.8.1
|
github.com/sirupsen/logrus v1.8.1
|
||||||
github.com/spf13/cobra v1.3.0
|
github.com/spf13/cobra v1.3.0
|
||||||
@@ -42,7 +42,7 @@ require (
|
|||||||
github.com/rs/xid v1.3.0
|
github.com/rs/xid v1.3.0
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/stretchr/testify v1.7.1
|
github.com/stretchr/testify v1.7.1
|
||||||
golang.org/x/net v0.0.0-20220513224357-95641704303c
|
golang.org/x/net v0.0.0-20220630215102-69896b714898
|
||||||
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
|
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,13 +80,13 @@ require (
|
|||||||
github.com/nxadm/tail v1.4.8 // indirect
|
github.com/nxadm/tail v1.4.8 // indirect
|
||||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||||
github.com/pion/dtls/v2 v2.1.2 // indirect
|
github.com/pion/dtls/v2 v2.1.5 // indirect
|
||||||
github.com/pion/logging v0.2.2 // indirect
|
github.com/pion/logging v0.2.2 // indirect
|
||||||
github.com/pion/mdns v0.0.5 // indirect
|
github.com/pion/mdns v0.0.5 // indirect
|
||||||
github.com/pion/randutil v0.1.0 // indirect
|
github.com/pion/randutil v0.1.0 // indirect
|
||||||
github.com/pion/stun v0.3.5 // indirect
|
github.com/pion/stun v0.3.5 // indirect
|
||||||
github.com/pion/transport v0.13.0 // indirect
|
github.com/pion/transport v0.13.1 // indirect
|
||||||
github.com/pion/turn/v2 v2.0.7 // indirect
|
github.com/pion/turn/v2 v2.0.8 // indirect
|
||||||
github.com/pion/udp v0.1.1 // indirect
|
github.com/pion/udp v0.1.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/prometheus/client_golang v1.12.2 // indirect
|
github.com/prometheus/client_golang v1.12.2 // indirect
|
||||||
@@ -117,6 +117,4 @@ require (
|
|||||||
k8s.io/apimachinery v0.23.5 // indirect
|
k8s.io/apimachinery v0.23.5 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
|
|
||||||
|
|
||||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84
|
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84
|
||||||
|
|||||||
25
go.sum
25
go.sum
@@ -505,8 +505,10 @@ github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTK
|
|||||||
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
|
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
|
||||||
github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ=
|
github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ=
|
||||||
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
|
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
|
||||||
github.com/pion/dtls/v2 v2.1.2 h1:22Q1Jk9L++Yo7BIf9130MonNPfPVb+YgdYLeyQotuAA=
|
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
|
||||||
github.com/pion/dtls/v2 v2.1.2/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
|
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
|
||||||
|
github.com/pion/ice/v2 v2.2.7 h1:kG9tux3WdYUSqqqnf+O5zKlpy41PdlvLUBlYJeV2emQ=
|
||||||
|
github.com/pion/ice/v2 v2.2.7/go.mod h1:Ckj7cWZ717rtU01YoDQA9ntGWCk95D42uVZ8sI0EL+8=
|
||||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||||
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
|
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
|
||||||
@@ -516,10 +518,11 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB
|
|||||||
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
|
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
|
||||||
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
|
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
|
||||||
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
|
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
|
||||||
github.com/pion/transport v0.13.0 h1:KWTA5ZrQogizzYwPEciGtHPLwpAjE91FgXnyu+Hv2uY=
|
|
||||||
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
|
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
|
||||||
github.com/pion/turn/v2 v2.0.7 h1:SZhc00WDovK6czaN1RSiHqbwANtIO6wfZQsU0m0KNE8=
|
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
|
||||||
github.com/pion/turn/v2 v2.0.7/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
|
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
|
||||||
|
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
|
||||||
|
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
|
||||||
github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o=
|
github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o=
|
||||||
github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M=
|
github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
@@ -624,8 +627,6 @@ github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJ
|
|||||||
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||||
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb h1:CU1/+CEeCPvYXgfAyqTJXSQSf6hW3wsWM6Dfz6HkHEQ=
|
|
||||||
github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb/go.mod h1:XT1Nrb4OxbVFPffbQMbq4PaeEkpRLVzdphh3fjrw7DY=
|
|
||||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
@@ -662,7 +663,7 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y
|
|||||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
|
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
|
||||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
@@ -772,8 +773,10 @@ golang.org/x/net v0.0.0-20211208012354-db4efeb81f4b/go.mod h1:9nx3DQGgdP8bBQD5qx
|
|||||||
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
|
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
|
golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
|
||||||
|
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
@@ -906,6 +909,8 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
|
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
|
||||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
func WireguardModExists() bool {
|
func WireguardModuleIsLoaded() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,48 +1,29 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NativeLink struct {
|
type NativeLink struct {
|
||||||
Link *netlink.Link
|
Link *netlink.Link
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
|
||||||
func WireguardModExists() bool {
|
|
||||||
link := newWGLink("mustnotexist")
|
|
||||||
|
|
||||||
// We willingly try to create a device with an invalid
|
|
||||||
// MTU here as the validation of the MTU will be performed after
|
|
||||||
// the validation of the link kind and hence allows us to check
|
|
||||||
// for the existance of the wireguard module without actually
|
|
||||||
// creating a link.
|
|
||||||
//
|
|
||||||
// As a side-effect, this will also let the kernel lazy-load
|
|
||||||
// the wireguard module.
|
|
||||||
link.attrs.MTU = math.MaxInt
|
|
||||||
|
|
||||||
err := netlink.LinkAdd(link)
|
|
||||||
|
|
||||||
return errors.Is(err, syscall.EINVAL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
if WireguardModExists() {
|
if WireguardModuleIsLoaded() {
|
||||||
log.Info("using kernel WireGuard")
|
log.Info("using kernel WireGuard")
|
||||||
return w.createWithKernel()
|
return w.createWithKernel()
|
||||||
} else {
|
} else {
|
||||||
|
if !tunModuleIsLoaded() {
|
||||||
|
return fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
log.Info("using userspace WireGuard")
|
log.Info("using userspace WireGuard")
|
||||||
return w.createWithUserspace()
|
return w.createWithUserspace()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
return w.assignAddr(luid)
|
return w.assignAddr(luid)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WireguardModExists check if we can load wireguard mod (linux only)
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
func WireguardModExists() bool {
|
func WireguardModuleIsLoaded() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
349
iface/module_linux.go
Normal file
349
iface/module_linux.go
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
// Package iface provides wireguard network interface creation and management
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"io/fs"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Holds logic to check existence of kernel modules used by wireguard interfaces
|
||||||
|
// Copied from https://github.com/paultag/go-modprobe and
|
||||||
|
// https://github.com/pmorjan/kmod
|
||||||
|
|
||||||
|
type status int
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultModuleDir = "/lib/modules"
|
||||||
|
unknown status = iota
|
||||||
|
unloaded
|
||||||
|
unloading
|
||||||
|
loading
|
||||||
|
live
|
||||||
|
inuse
|
||||||
|
)
|
||||||
|
|
||||||
|
type module struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrModuleNotFound is the error resulting if a module can't be found.
|
||||||
|
ErrModuleNotFound = errors.New("module not found")
|
||||||
|
moduleLibDir = defaultModuleDir
|
||||||
|
// get the root directory for the kernel modules. If this line panics,
|
||||||
|
// it's because getModuleRoot has failed to get the uname of the running
|
||||||
|
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get the module root (/lib/modules/$(uname -r)/)
|
||||||
|
func getModuleRoot() string {
|
||||||
|
uname := unix.Utsname{}
|
||||||
|
if err := unix.Uname(&uname); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for ; uname.Release[i] != 0; i++ {
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
|
||||||
|
func tunModuleIsLoaded() bool {
|
||||||
|
_, err := os.Stat("/dev/net/tun")
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
|
||||||
|
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
|
||||||
|
|
||||||
|
tunLoaded, err := tryToLoadModule("tun")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to find or load tun module, got error: %v", err)
|
||||||
|
}
|
||||||
|
return tunLoaded
|
||||||
|
}
|
||||||
|
|
||||||
|
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||||
|
func WireguardModuleIsLoaded() bool {
|
||||||
|
if canCreateFakeWireguardInterface() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := tryToLoadModule("wireguard")
|
||||||
|
if err != nil {
|
||||||
|
log.Info(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
func canCreateFakeWireguardInterface() bool {
|
||||||
|
link := newWGLink("mustnotexist")
|
||||||
|
|
||||||
|
// We willingly try to create a device with an invalid
|
||||||
|
// MTU here as the validation of the MTU will be performed after
|
||||||
|
// the validation of the link kind and hence allows us to check
|
||||||
|
// for the existance of the wireguard module without actually
|
||||||
|
// creating a link.
|
||||||
|
//
|
||||||
|
// As a side-effect, this will also let the kernel lazy-load
|
||||||
|
// the wireguard module.
|
||||||
|
link.attrs.MTU = math.MaxInt
|
||||||
|
|
||||||
|
err := netlink.LinkAdd(link)
|
||||||
|
|
||||||
|
return errors.Is(err, syscall.EINVAL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryToLoadModule(moduleName string) (bool, error) {
|
||||||
|
if isModuleEnabled(moduleName) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
modulePath, err := getModulePath(moduleName)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
|
||||||
|
}
|
||||||
|
if modulePath == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("trying to load %s module", moduleName)
|
||||||
|
|
||||||
|
err = loadModuleWithDependencies(moduleName, modulePath)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isModuleEnabled(name string) bool {
|
||||||
|
builtin, builtinErr := isBuiltinModule(name)
|
||||||
|
state, statusErr := moduleStatus(name)
|
||||||
|
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModulePath(name string) (string, error) {
|
||||||
|
var foundPath string
|
||||||
|
skipRemainingDirs := false
|
||||||
|
|
||||||
|
err := filepath.WalkDir(
|
||||||
|
moduleRoot,
|
||||||
|
func(path string, info fs.DirEntry, err error) error {
|
||||||
|
if skipRemainingDirs {
|
||||||
|
return fs.SkipDir
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// skip broken files
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.Type().IsRegular() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nameFromPath := pathToName(path)
|
||||||
|
if nameFromPath == name {
|
||||||
|
foundPath = path
|
||||||
|
skipRemainingDirs = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return foundPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathToName(s string) string {
|
||||||
|
s = filepath.Base(s)
|
||||||
|
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
|
||||||
|
s = strings.TrimSuffix(s, ext)
|
||||||
|
}
|
||||||
|
return cleanName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanName(s string) string {
|
||||||
|
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBuiltinModule(name string) (bool, error) {
|
||||||
|
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing modules.builtin file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if pathToName(line) == name {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// /proc/modules
|
||||||
|
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
|
||||||
|
// macvlan 28672 1 macvtap, Live 0x0000000000000000
|
||||||
|
func moduleStatus(name string) (status, error) {
|
||||||
|
state := unknown
|
||||||
|
f, err := os.Open("/proc/modules")
|
||||||
|
if err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing /proc/modules file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
state = unloaded
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if fields[0] == name {
|
||||||
|
if fields[2] != "0" {
|
||||||
|
state = inuse
|
||||||
|
break
|
||||||
|
}
|
||||||
|
switch fields[4] {
|
||||||
|
case "Live":
|
||||||
|
state = live
|
||||||
|
case "Loading":
|
||||||
|
state = loading
|
||||||
|
case "Unloading":
|
||||||
|
state = unloading
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModuleWithDependencies(name, path string) error {
|
||||||
|
deps, err := getModuleDependencies(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't load list of module %s dependecies", name)
|
||||||
|
}
|
||||||
|
for _, dep := range deps {
|
||||||
|
err = loadModule(dep.name, dep.path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return loadModule(name, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModule(name, path string) error {
|
||||||
|
state, err := moduleStatus(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if state >= loading {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing %s file, %v", path, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// first try finit_module(2), then init_module(2)
|
||||||
|
err = unix.FinitModule(int(f.Fd()), "", 0)
|
||||||
|
if errors.Is(err, unix.ENOSYS) {
|
||||||
|
buf, err := ioutil.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return unix.InitModule(buf, "")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getModuleDependencies returns a module dependencies
|
||||||
|
func getModuleDependencies(name string) ([]module, error) {
|
||||||
|
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed closing modules.dep file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var deps []string
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
|
||||||
|
deps = fields
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(deps) == 0 {
|
||||||
|
return nil, ErrModuleNotFound
|
||||||
|
}
|
||||||
|
deps[0] = strings.TrimSuffix(deps[0], ":")
|
||||||
|
|
||||||
|
var modules []module
|
||||||
|
for _, v := range deps {
|
||||||
|
if pathToName(v) != name {
|
||||||
|
modules = append(modules, module{
|
||||||
|
name: pathToName(v),
|
||||||
|
path: filepath.Join(moduleRoot, v),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modules, nil
|
||||||
|
}
|
||||||
221
iface/module_linux_test.go
Normal file
221
iface/module_linux_test.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModuleDependencies(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
expected []module
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Single Dependency",
|
||||||
|
module: "bar",
|
||||||
|
expected: []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Multiple Dependencies",
|
||||||
|
module: "baz",
|
||||||
|
expected: []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
{name: "bar", path: "kernel/a/bar.ko"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get No Dependencies",
|
||||||
|
module: "foo",
|
||||||
|
expected: []module{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
modules, err := getModuleDependencies(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected := testCase.expected
|
||||||
|
for i := range expected {
|
||||||
|
expected[i].path = moduleRoot + "/" + expected[i].path
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatchf(t, modules, expected, "returned modules should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBuiltinModule(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Built In Should Return True",
|
||||||
|
module: "foo_bi",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not Built In Should Return False",
|
||||||
|
module: "not_built_in",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
|
||||||
|
isBuiltIn, err := isBuiltinModule(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, testCase.expected, isBuiltIn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModuleStatus(t *testing.T) {
|
||||||
|
random, err := getRandomLoadedModule(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("should be able to get random module")
|
||||||
|
}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
module string
|
||||||
|
shouldBeLoaded bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Return Module Loading Or Greater Status",
|
||||||
|
module: random,
|
||||||
|
shouldBeLoaded: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Return Module Unloaded Or Lower Status",
|
||||||
|
module: "not_loaded_module",
|
||||||
|
shouldBeLoaded: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
defer resetGlobals()
|
||||||
|
_, _ = createFiles(t)
|
||||||
|
|
||||||
|
state, err := moduleStatus(testCase.module)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if testCase.shouldBeLoaded {
|
||||||
|
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
|
||||||
|
} else {
|
||||||
|
require.Less(t, state, loading, "module should return state unloading or lower")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetGlobals() {
|
||||||
|
moduleLibDir = defaultModuleDir
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFiles(t *testing.T) (string, []module) {
|
||||||
|
writeFile := func(path, text string) {
|
||||||
|
if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var u unix.Utsname
|
||||||
|
if err := unix.Uname(&u); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLibDir = t.TempDir()
|
||||||
|
|
||||||
|
moduleRoot = getModuleRoot()
|
||||||
|
if err := os.Mkdir(moduleRoot, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
text := "kernel/a/foo.ko:\n"
|
||||||
|
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
|
||||||
|
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
|
||||||
|
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
|
||||||
|
|
||||||
|
text = "kernel/a/foo_bi.ko\n"
|
||||||
|
text += "kernel/a/bar-bi.ko.gz\n"
|
||||||
|
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
|
||||||
|
|
||||||
|
modules := []module{
|
||||||
|
{name: "foo", path: "kernel/a/foo.ko"},
|
||||||
|
{name: "bar", path: "kernel/a/bar.ko"},
|
||||||
|
{name: "baz", path: "kernel/a/baz.ko"},
|
||||||
|
}
|
||||||
|
return moduleLibDir, modules
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRandomLoadedModule(t *testing.T) (string, error) {
|
||||||
|
f, err := os.Open("/proc/modules")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := f.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed closing /proc/modules file, %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
lines, err := lineCounter(f)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
counter := 1
|
||||||
|
midLine := lines / 2
|
||||||
|
modName := ""
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if counter == midLine {
|
||||||
|
if fields[4] == "Unloading" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modName = fields[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
counter++
|
||||||
|
}
|
||||||
|
if scanner.Err() != nil {
|
||||||
|
return "", scanner.Err()
|
||||||
|
}
|
||||||
|
return modName, nil
|
||||||
|
}
|
||||||
|
func lineCounter(r io.Reader) (int, error) {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
count := 0
|
||||||
|
lineSep := []byte{'\n'}
|
||||||
|
|
||||||
|
for {
|
||||||
|
c, err := r.Read(buf)
|
||||||
|
count += bytes.Count(buf[:c], lineSep)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err == io.EOF:
|
||||||
|
return count, nil
|
||||||
|
|
||||||
|
case err != nil:
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -31,20 +31,23 @@ const (
|
|||||||
type AccountManager interface {
|
type AccountManager interface {
|
||||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||||
GetAccountByUser(userId string) (*Account, error)
|
GetAccountByUser(userId string) (*Account, error)
|
||||||
AddSetupKey(
|
CreateSetupKey(
|
||||||
accountId string,
|
accountId string,
|
||||||
keyName string,
|
keyName string,
|
||||||
keyType SetupKeyType,
|
keyType SetupKeyType,
|
||||||
expiresIn time.Duration,
|
expiresIn time.Duration,
|
||||||
|
autoGroups []string,
|
||||||
) (*SetupKey, error)
|
) (*SetupKey, error)
|
||||||
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
|
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
|
||||||
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
SaveUser(accountID string, key *User) (*UserInfo, error)
|
||||||
|
GetSetupKey(accountID, keyID string) (*SetupKey, error)
|
||||||
GetAccountById(accountId string) (*Account, error)
|
GetAccountById(accountId string) (*Account, error)
|
||||||
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
|
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
|
||||||
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||||
AccountExists(accountId string) (*bool, error)
|
AccountExists(accountId string) (*bool, error)
|
||||||
GetPeer(peerKey string) (*Peer, error)
|
GetPeer(peerKey string) (*Peer, error)
|
||||||
|
GetPeers(accountID string) ([]*PeerInfo, error)
|
||||||
MarkPeerConnected(peerKey string, connected bool) error
|
MarkPeerConnected(peerKey string, connected bool) error
|
||||||
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
|
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
|
||||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||||
@@ -55,7 +58,7 @@ type AccountManager interface {
|
|||||||
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
|
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
|
||||||
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
|
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
|
||||||
UpdatePeerSSHKey(peerKey string, sshKey string) error
|
UpdatePeerSSHKey(peerKey string, sshKey string) error
|
||||||
GetUsersFromAccount(accountId string) ([]*UserInfo, error)
|
GetUsers(accountId string) ([]*UserInfo, error)
|
||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
SaveGroup(accountId string, group *Group) error
|
SaveGroup(accountId string, group *Group) error
|
||||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||||
@@ -75,6 +78,7 @@ type AccountManager interface {
|
|||||||
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
|
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
|
||||||
DeleteRoute(accountID, routeID string) error
|
DeleteRoute(accountID, routeID string) error
|
||||||
ListRoutes(accountID string) ([]*route.Route, error)
|
ListRoutes(accountID string) ([]*route.Route, error)
|
||||||
|
ListSetupKeys(accountID string) ([]*SetupKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@@ -105,10 +109,11 @@ type Account struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UserInfo struct {
|
type UserInfo struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) Copy() *Account {
|
func (a *Account) Copy() *Account {
|
||||||
@@ -244,93 +249,6 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
|
|
||||||
func (am *DefaultAccountManager) AddSetupKey(
|
|
||||||
accountId string,
|
|
||||||
keyName string,
|
|
||||||
keyType SetupKeyType,
|
|
||||||
expiresIn time.Duration,
|
|
||||||
) (*SetupKey, error) {
|
|
||||||
am.mux.Lock()
|
|
||||||
defer am.mux.Unlock()
|
|
||||||
|
|
||||||
keyDuration := DefaultSetupKeyDuration
|
|
||||||
if expiresIn != 0 {
|
|
||||||
keyDuration = expiresIn
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration)
|
|
||||||
account.SetupKeys[setupKey.Key] = setupKey
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return setupKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
|
|
||||||
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
|
|
||||||
am.mux.Lock()
|
|
||||||
defer am.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
setupKey := getAccountSetupKeyById(account, keyId)
|
|
||||||
if setupKey == nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
|
|
||||||
}
|
|
||||||
|
|
||||||
keyCopy := setupKey.Copy()
|
|
||||||
keyCopy.Revoked = true
|
|
||||||
account.SetupKeys[keyCopy.Key] = keyCopy
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return keyCopy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RenameSetupKey renames existing setup key of the specified account.
|
|
||||||
func (am *DefaultAccountManager) RenameSetupKey(
|
|
||||||
accountId string,
|
|
||||||
keyId string,
|
|
||||||
newName string,
|
|
||||||
) (*SetupKey, error) {
|
|
||||||
am.mux.Lock()
|
|
||||||
defer am.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
setupKey := getAccountSetupKeyById(account, keyId)
|
|
||||||
if setupKey == nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
|
|
||||||
}
|
|
||||||
|
|
||||||
keyCopy := setupKey.Copy()
|
|
||||||
keyCopy.Name = newName
|
|
||||||
account.SetupKeys[keyCopy.Key] = keyCopy
|
|
||||||
err = am.Store.SaveAccount(account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return keyCopy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
|
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
|
||||||
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
|
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
@@ -385,19 +303,25 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
|
|
||||||
return &UserInfo{
|
|
||||||
ID: local.Id,
|
|
||||||
Email: queried.Email,
|
|
||||||
Name: queried.Name,
|
|
||||||
Role: string(local.Role),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
||||||
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) {
|
||||||
|
userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, datum := range userData {
|
||||||
|
if datum.ID == user.Id {
|
||||||
|
return datum, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id)
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
|
||||||
data, err := am.cacheManager.Get(am.ctx, accountID)
|
data, err := am.cacheManager.Get(am.ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -437,46 +361,6 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco
|
|||||||
return userData, err
|
return userData, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount performs a batched request for users from IDP by account id
|
|
||||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
|
|
||||||
account, err := am.GetAccountById(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
queriedUsers := make([]*idp.UserData, 0)
|
|
||||||
if !isNil(am.idpManager) {
|
|
||||||
queriedUsers, err = am.lookupCache(account.Users, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
userInfo := make([]*UserInfo, 0)
|
|
||||||
|
|
||||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
|
||||||
if len(queriedUsers) == 0 {
|
|
||||||
for _, user := range account.Users {
|
|
||||||
userInfo = append(userInfo, &UserInfo{
|
|
||||||
ID: user.Id,
|
|
||||||
Email: "",
|
|
||||||
Name: "",
|
|
||||||
Role: string(user.Role),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return userInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, queriedUser := range queriedUsers {
|
|
||||||
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
|
||||||
userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser))
|
|
||||||
log.Debugf("Merged userinfo to send back; %v", userInfo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return userInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||||
account *Account,
|
account *Account,
|
||||||
@@ -504,7 +388,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
|||||||
|
|
||||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||||
//
|
//
|
||||||
//
|
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
@@ -688,7 +571,7 @@ func newAccountWithId(accountId, userId, domain string) *Account {
|
|||||||
|
|
||||||
setupKeys := make(map[string]*SetupKey)
|
setupKeys := make(map[string]*SetupKey)
|
||||||
defaultKey := GenerateDefaultSetupKey()
|
defaultKey := GenerateDefaultSetupKey()
|
||||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration)
|
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{})
|
||||||
setupKeys[defaultKey.Key] = defaultKey
|
setupKeys[defaultKey.Key] = defaultKey
|
||||||
setupKeys[oneOffKey.Key] = oneOffKey
|
setupKeys[oneOffKey.Key] = oneOffKey
|
||||||
network := NewNetwork()
|
network := NewNetwork()
|
||||||
@@ -713,15 +596,6 @@ func newAccountWithId(accountId, userId, domain string) *Account {
|
|||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
|
|
||||||
for _, k := range acc.SetupKeys {
|
|
||||||
if keyId == k.Id {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
|
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
|
||||||
for _, k := range acc.SetupKeys {
|
for _, k := range acc.SetupKeys {
|
||||||
if key == k.Key {
|
if key == k.Key {
|
||||||
|
|||||||
@@ -847,7 +847,7 @@ func TestGetUsersFromAccount(t *testing.T) {
|
|||||||
account.Users[user.Id] = user
|
account.Users[user.Id] = user
|
||||||
}
|
}
|
||||||
|
|
||||||
userInfos, err := manager.GetUsersFromAccount(accountId)
|
userInfos, err := manager.GetUsers(accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
gRPCPeer "google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -79,7 +80,10 @@ func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto
|
|||||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
log.Debugf("Sync request from peer %s", req.WgPubKey)
|
p, ok := gRPCPeer.FromContext(srv.Context())
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -255,7 +259,10 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
|||||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||||
// In case of the successful registration login is also successful
|
// In case of the successful registration login is also successful
|
||||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
log.Debugf("Login request from peer %s", req.WgPubKey)
|
p, ok := gRPCPeer.FromContext(ctx)
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -31,13 +31,45 @@ components:
|
|||||||
description: User's name from idp provider
|
description: User's name from idp provider
|
||||||
type: string
|
type: string
|
||||||
role:
|
role:
|
||||||
description: User's Netbird account role
|
description: User's NetBird account role
|
||||||
type: string
|
type: string
|
||||||
|
auto_groups:
|
||||||
|
description: Groups to auto-assign to peers registered by this user
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- email
|
- email
|
||||||
- name
|
- name
|
||||||
- role
|
- role
|
||||||
|
- auto_groups
|
||||||
|
UserMinimum:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
description: User ID
|
||||||
|
type: string
|
||||||
|
email:
|
||||||
|
description: User's email address
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
description: User's name from idp provider
|
||||||
|
type: string
|
||||||
|
UserRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
role:
|
||||||
|
description: User's NetBird account role
|
||||||
|
type: string
|
||||||
|
auto_groups:
|
||||||
|
description: Groups to auto-assign to peers registered by this user
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- role
|
||||||
|
- auto_groups
|
||||||
PeerMinimum:
|
PeerMinimum:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -90,6 +122,11 @@ components:
|
|||||||
ssh_enabled:
|
ssh_enabled:
|
||||||
description: Indicates whether SSH server is enabled on this peer
|
description: Indicates whether SSH server is enabled on this peer
|
||||||
type: boolean
|
type: boolean
|
||||||
|
user:
|
||||||
|
$ref: '#/components/schemas/UserMinimum'
|
||||||
|
host_name:
|
||||||
|
description: Peer's hostname
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- ip
|
- ip
|
||||||
- connected
|
- connected
|
||||||
@@ -134,6 +171,15 @@ components:
|
|||||||
state:
|
state:
|
||||||
description: Setup key status, "valid", "overused","expired" or "revoked"
|
description: Setup key status, "valid", "overused","expired" or "revoked"
|
||||||
type: string
|
type: string
|
||||||
|
auto_groups:
|
||||||
|
description: Setup key groups to auto-assign to peers registered with this key
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
updated_at:
|
||||||
|
description: Setup key last update date
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- key
|
- key
|
||||||
@@ -145,6 +191,8 @@ components:
|
|||||||
- used_times
|
- used_times
|
||||||
- last_used
|
- last_used
|
||||||
- state
|
- state
|
||||||
|
- auto_groups
|
||||||
|
- updated_at
|
||||||
SetupKeyRequest:
|
SetupKeyRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -160,11 +208,17 @@ components:
|
|||||||
revoked:
|
revoked:
|
||||||
description: Setup key revocation status
|
description: Setup key revocation status
|
||||||
type: boolean
|
type: boolean
|
||||||
|
auto_groups:
|
||||||
|
description: Setup key groups to auto-assign to peers registered with this key
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- type
|
- type
|
||||||
- expires_in
|
- expires_in
|
||||||
- revoked
|
- revoked
|
||||||
|
- auto_groups
|
||||||
GroupMinimum:
|
GroupMinimum:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -392,6 +446,40 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/users/{id}:
|
||||||
|
put:
|
||||||
|
summary: Update information about a User
|
||||||
|
tags: [ Users]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: id
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The User ID
|
||||||
|
requestBody:
|
||||||
|
description: User update
|
||||||
|
content:
|
||||||
|
'application/json':
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/UserRequest'
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: A User object
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/User'
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/peers:
|
/api/peers:
|
||||||
get:
|
get:
|
||||||
summary: Returns a list of all peers
|
summary: Returns a list of all peers
|
||||||
|
|||||||
@@ -137,6 +137,9 @@ type Peer struct {
|
|||||||
// Groups that the peer belongs to
|
// Groups that the peer belongs to
|
||||||
Groups []GroupMinimum `json:"groups"`
|
Groups []GroupMinimum `json:"groups"`
|
||||||
|
|
||||||
|
// Peer's hostname
|
||||||
|
HostName *string `json:"host_name,omitempty"`
|
||||||
|
|
||||||
// Peer ID
|
// Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
@@ -153,7 +156,8 @@ type Peer struct {
|
|||||||
Os string `json:"os"`
|
Os string `json:"os"`
|
||||||
|
|
||||||
// Indicates whether SSH server is enabled on this peer
|
// Indicates whether SSH server is enabled on this peer
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
User *UserMinimum `json:"user,omitempty"`
|
||||||
|
|
||||||
// Peer's daemon or cli version
|
// Peer's daemon or cli version
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
@@ -299,6 +303,9 @@ type RulePatchOperationPath string
|
|||||||
|
|
||||||
// SetupKey defines model for SetupKey.
|
// SetupKey defines model for SetupKey.
|
||||||
type SetupKey struct {
|
type SetupKey struct {
|
||||||
|
// Setup key groups to auto-assign to peers registered with this key
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
// Setup Key expiration date
|
// Setup Key expiration date
|
||||||
Expires time.Time `json:"expires"`
|
Expires time.Time `json:"expires"`
|
||||||
|
|
||||||
@@ -323,6 +330,9 @@ type SetupKey struct {
|
|||||||
// Setup key type, one-off for single time usage and reusable
|
// Setup key type, one-off for single time usage and reusable
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Setup key last update date
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
// Usage count of setup key
|
// Usage count of setup key
|
||||||
UsedTimes int `json:"used_times"`
|
UsedTimes int `json:"used_times"`
|
||||||
|
|
||||||
@@ -332,6 +342,9 @@ type SetupKey struct {
|
|||||||
|
|
||||||
// SetupKeyRequest defines model for SetupKeyRequest.
|
// SetupKeyRequest defines model for SetupKeyRequest.
|
||||||
type SetupKeyRequest struct {
|
type SetupKeyRequest struct {
|
||||||
|
// Setup key groups to auto-assign to peers registered with this key
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
// Expiration time in seconds
|
// Expiration time in seconds
|
||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
|
||||||
@@ -347,6 +360,9 @@ type SetupKeyRequest struct {
|
|||||||
|
|
||||||
// User defines model for User.
|
// User defines model for User.
|
||||||
type User struct {
|
type User struct {
|
||||||
|
// Groups to auto-assign to peers registered by this user
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
// User's email address
|
// User's email address
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
|
||||||
@@ -356,7 +372,28 @@ type User struct {
|
|||||||
// User's name from idp provider
|
// User's name from idp provider
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
// User's Netbird account role
|
// User's NetBird account role
|
||||||
|
Role string `json:"role"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMinimum defines model for UserMinimum.
|
||||||
|
type UserMinimum struct {
|
||||||
|
// User's email address
|
||||||
|
Email *string `json:"email,omitempty"`
|
||||||
|
|
||||||
|
// User ID
|
||||||
|
Id *string `json:"id,omitempty"`
|
||||||
|
|
||||||
|
// User's name from idp provider
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserRequest defines model for UserRequest.
|
||||||
|
type UserRequest struct {
|
||||||
|
// Groups to auto-assign to peers registered by this user
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
|
// User's NetBird account role
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -433,6 +470,9 @@ type PostApiSetupKeysJSONBody = SetupKeyRequest
|
|||||||
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId.
|
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId.
|
||||||
type PutApiSetupKeysIdJSONBody = SetupKeyRequest
|
type PutApiSetupKeysIdJSONBody = SetupKeyRequest
|
||||||
|
|
||||||
|
// PutApiUsersIdJSONBody defines parameters for PutApiUsersId.
|
||||||
|
type PutApiUsersIdJSONBody = UserRequest
|
||||||
|
|
||||||
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
||||||
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
|
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
|
||||||
|
|
||||||
@@ -468,3 +508,6 @@ type PostApiSetupKeysJSONRequestBody = PostApiSetupKeysJSONBody
|
|||||||
|
|
||||||
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
|
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
|
||||||
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody
|
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody
|
||||||
|
|
||||||
|
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
|
||||||
|
type PutApiUsersIdJSONRequestBody = PutApiUsersIdJSONBody
|
||||||
|
|||||||
@@ -39,12 +39,12 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
|||||||
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
|
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
|
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
|
||||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
|
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
|
|
||||||
|
|
||||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
|
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
|
||||||
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
|
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
|
||||||
|
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
|
||||||
|
|
||||||
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
|
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
|
||||||
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
|
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
//Peers is a handler that returns peers of the account
|
// Peers is a handler that returns peers of the account
|
||||||
type Peers struct {
|
type Peers struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
authAudience string
|
authAudience string
|
||||||
@@ -42,7 +42,7 @@ func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.Re
|
|||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSONObject(w, toPeerResponse(peer, account))
|
writeJSONObject(w, toPeerResponse(&server.PeerInfo{Peer: peer}, account))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -83,7 +83,7 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.updatePeer(account, peer, w, r)
|
h.updatePeer(account, peer, w, r)
|
||||||
return
|
return
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
writeJSONObject(w, toPeerResponse(peer, account))
|
writeJSONObject(w, toPeerResponse(&server.PeerInfo{Peer: peer}, account))
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -93,27 +93,34 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.Method {
|
|
||||||
case http.MethodGet:
|
|
||||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody := []*api.Peer{}
|
if r.Method != http.MethodGet {
|
||||||
for _, peer := range account.Peers {
|
|
||||||
respBody = append(respBody, toPeerResponse(peer, account))
|
|
||||||
}
|
|
||||||
writeJSONObject(w, respBody)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
http.Error(w, "", http.StatusNotFound)
|
http.Error(w, "", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peers, err := h.accountManager.GetPeers(account.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := []*api.Peer{}
|
||||||
|
for _, peer := range peers {
|
||||||
|
respBody = append(respBody, toPeerResponse(peer, account))
|
||||||
|
}
|
||||||
|
writeJSONObject(w, respBody)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
func toPeerResponse(peer *server.PeerInfo, account *server.Account) *api.Peer {
|
||||||
var groupsInfo []api.GroupMinimum
|
var groupsInfo []api.GroupMinimum
|
||||||
groupsChecked := make(map[string]struct{})
|
groupsChecked := make(map[string]struct{})
|
||||||
for _, group := range account.Groups {
|
for _, group := range account.Groups {
|
||||||
@@ -123,7 +130,7 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
|||||||
}
|
}
|
||||||
groupsChecked[group.ID] = struct{}{}
|
groupsChecked[group.ID] = struct{}{}
|
||||||
for _, pk := range group.Peers {
|
for _, pk := range group.Peers {
|
||||||
if pk == peer.Key {
|
if pk == peer.Peer.Key {
|
||||||
info := api.GroupMinimum{
|
info := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@@ -134,15 +141,26 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &api.Peer{
|
resp := &api.Peer{
|
||||||
Id: peer.IP.String(),
|
Id: peer.Peer.IP.String(),
|
||||||
Name: peer.Name,
|
Name: peer.Peer.Name,
|
||||||
Ip: peer.IP.String(),
|
Ip: peer.Peer.IP.String(),
|
||||||
Connected: peer.Status.Connected,
|
Connected: peer.Peer.Status.Connected,
|
||||||
LastSeen: peer.Status.LastSeen,
|
LastSeen: peer.Peer.Status.LastSeen,
|
||||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core),
|
Os: fmt.Sprintf("%s %s", peer.Peer.Meta.OS, peer.Peer.Meta.Core),
|
||||||
Version: peer.Meta.WtVersion,
|
Version: peer.Peer.Meta.WtVersion,
|
||||||
Groups: groupsInfo,
|
Groups: groupsInfo,
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.Peer.SSHEnabled,
|
||||||
|
HostName: &peer.Peer.Meta.Hostname,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.UserInfo != nil {
|
||||||
|
resp.User = &api.UserMinimum{
|
||||||
|
Email: &peer.UserInfo.Email,
|
||||||
|
Id: &peer.UserInfo.ID,
|
||||||
|
Name: &peer.UserInfo.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -348,6 +348,11 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
errStatus, ok := status.FromError(err)
|
||||||
|
if ok && errStatus.Code() == codes.NotFound {
|
||||||
|
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
|
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -78,7 +78,10 @@ func initRoutesTestData() *Routes {
|
|||||||
SaveRouteFunc: func(_ string, _ *route.Route) error {
|
SaveRouteFunc: func(_ string, _ *route.Route) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ string, _ string) error {
|
DeleteRouteFunc: func(_ string, peerIP string) error {
|
||||||
|
if peerIP != existingRouteID {
|
||||||
|
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||||
@@ -155,7 +158,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Get Not Existing Route",
|
name: "Get Not Existing Route",
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/rules/" + notFoundRouteID,
|
requestPath: "/api/routes/" + notFoundRouteID,
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -168,7 +171,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Delete Not Existing Route",
|
name: "Delete Not Existing Route",
|
||||||
requestType: http.MethodDelete,
|
requestType: http.MethodDelete,
|
||||||
requestPath: "/api/rules/" + notFoundRouteID,
|
requestPath: "/api/routes/" + notFoundRouteID,
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -28,54 +29,17 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
||||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var key *server.SetupKey
|
|
||||||
if req.Revoked {
|
|
||||||
//handle only if being revoked, don't allow to enable key again for now
|
|
||||||
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "failed revoking key", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(req.Name) != 0 {
|
|
||||||
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "failed renaming key", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if key != nil {
|
|
||||||
writeSuccess(w, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
|
||||||
account, err := h.accountManager.GetAccountById(accountId)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
if key.Id == keyId {
|
|
||||||
writeSuccess(w, key)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
http.Error(w, "setup key not found", http.StatusNotFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.Request) {
|
|
||||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@@ -95,7 +59,13 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
|||||||
|
|
||||||
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||||
|
|
||||||
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, server.SetupKeyType(req.Type), expiresIn)
|
if req.AutoGroups == nil {
|
||||||
|
req.AutoGroups = []string{}
|
||||||
|
}
|
||||||
|
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||||
|
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
|
||||||
|
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
|
req.AutoGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
errStatus, ok := status.FromError(err)
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
if ok && errStatus.Code() == codes.NotFound {
|
||||||
@@ -109,7 +79,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
|||||||
writeSuccess(w, setupKey)
|
writeSuccess(w, setupKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
||||||
|
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
@@ -118,25 +89,84 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyId := vars["id"]
|
keyID := vars["id"]
|
||||||
if len(keyId) == 0 {
|
if len(keyID) == 0 {
|
||||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
|
||||||
case http.MethodPut:
|
if err != nil {
|
||||||
h.updateKey(account.Id, keyId, w, r)
|
errStatus, ok := status.FromError(err)
|
||||||
|
if ok && errStatus.Code() == codes.NotFound {
|
||||||
|
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
case http.MethodGet:
|
|
||||||
h.getKey(account.Id, keyId, w, r)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
http.Error(w, "", http.StatusNotFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
writeSuccess(w, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
||||||
|
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
keyID := vars["id"]
|
||||||
|
if len(keyID) == 0 {
|
||||||
|
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
||||||
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.AutoGroups == nil {
|
||||||
|
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newKey := &server.SetupKey{}
|
||||||
|
newKey.AutoGroups = req.AutoGroups
|
||||||
|
newKey.Revoked = req.Revoked
|
||||||
|
newKey.Name = req.Name
|
||||||
|
newKey.Id = keyID
|
||||||
|
|
||||||
|
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); ok {
|
||||||
|
switch e.Code() {
|
||||||
|
case codes.NotFound:
|
||||||
|
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
|
||||||
|
default:
|
||||||
|
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeSuccess(w, newKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
||||||
|
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -145,28 +175,18 @@ func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
|
||||||
case http.MethodPost:
|
if err != nil {
|
||||||
h.createKey(account.Id, w, r)
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
case http.MethodGet:
|
|
||||||
w.WriteHeader(200)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
respBody := []*api.SetupKey{}
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
respBody = append(respBody, toResponseBody(key))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.NewEncoder(w).Encode(respBody)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed encoding account peers %s: %v", account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
http.Error(w, "", http.StatusNotFound)
|
|
||||||
}
|
}
|
||||||
|
apiSetupKeys := make([]*api.SetupKey, 0)
|
||||||
|
for _, key := range setupKeys {
|
||||||
|
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||||
@@ -190,16 +210,19 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
|
|||||||
} else {
|
} else {
|
||||||
state = "valid"
|
state = "valid"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.SetupKey{
|
return &api.SetupKey{
|
||||||
Id: key.Id,
|
Id: key.Id,
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Name: key.Name,
|
Name: key.Name,
|
||||||
Expires: key.ExpiresAt,
|
Expires: key.ExpiresAt,
|
||||||
Type: string(key.Type),
|
Type: string(key.Type),
|
||||||
Valid: key.IsValid(),
|
Valid: key.IsValid(),
|
||||||
Revoked: key.Revoked,
|
Revoked: key.Revoked,
|
||||||
UsedTimes: key.UsedTimes,
|
UsedTimes: key.UsedTimes,
|
||||||
LastUsed: key.LastUsed,
|
LastUsed: key.LastUsed,
|
||||||
State: state,
|
State: state,
|
||||||
|
AutoGroups: key.AutoGroups,
|
||||||
|
UpdatedAt: key.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
222
management/server/http/setupkeys_test.go
Normal file
222
management/server/http/setupkeys_test.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
existingSetupKeyID = "existingSetupKeyID"
|
||||||
|
newSetupKeyName = "New Setup Key"
|
||||||
|
updatedSetupKeyName = "KKKey"
|
||||||
|
notFoundSetupKeyID = "notFoundSetupKeyID"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
|
||||||
|
return &SetupKeys{
|
||||||
|
accountManager: &mock_server.MockAccountManager{
|
||||||
|
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||||
|
return &server.Account{
|
||||||
|
Id: testAccountID,
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
SetupKeys: map[string]*server.SetupKey{
|
||||||
|
defaultKey.Key: defaultKey,
|
||||||
|
},
|
||||||
|
Groups: map[string]*server.Group{
|
||||||
|
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||||
|
"id-all": {ID: "id-all", Name: "All"}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
|
||||||
|
if keyName == newKey.Name || typ != newKey.Type {
|
||||||
|
return newKey, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed creating setup key")
|
||||||
|
},
|
||||||
|
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
|
||||||
|
switch keyID {
|
||||||
|
case defaultKey.Id:
|
||||||
|
return defaultKey, nil
|
||||||
|
case newKey.Id:
|
||||||
|
return newKey, nil
|
||||||
|
default:
|
||||||
|
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
|
||||||
|
if key.Id == updatedSetupKey.Id {
|
||||||
|
return updatedSetupKey, nil
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
|
||||||
|
},
|
||||||
|
|
||||||
|
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
|
||||||
|
return []*server.SetupKey{defaultKey}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
authAudience: "",
|
||||||
|
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||||
|
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
|
||||||
|
return jwtclaims.AuthorizationClaims{
|
||||||
|
UserId: "test_user",
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
AccountId: testAccountID,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupKeysHandlers(t *testing.T) {
|
||||||
|
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||||
|
defaultSetupKey.Id = existingSetupKeyID
|
||||||
|
|
||||||
|
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
|
||||||
|
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
||||||
|
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
||||||
|
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
||||||
|
updatedDefaultSetupKey.Revoked = true
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestBody io.Reader
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
expectedSetupKey *api.SetupKey
|
||||||
|
expectedSetupKeys []*api.SetupKey
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Setup Keys",
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/setup-keys",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Existing Setup Key",
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/setup-keys/" + existingSetupKeyID,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedSetupKey: toResponseBody(defaultSetupKey),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Not Existing Setup Key",
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/setup-keys/" + notFoundSetupKeyID,
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
expectedBody: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Create Setup Key",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/setup-keys",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\"}", newSetupKey.Name, newSetupKey.Type))),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedSetupKey: toResponseBody(newSetupKey),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update Setup Key",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"auto_groups\":[\"%s\"], \"revoked\":%v}",
|
||||||
|
updatedDefaultSetupKey.Type,
|
||||||
|
updatedDefaultSetupKey.AutoGroups[0],
|
||||||
|
updatedDefaultSetupKey.Revoked,
|
||||||
|
))),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/api/setup-keys", handler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||||
|
status, tc.expectedStatus, string(content))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.expectedBody {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedSetupKey != nil {
|
||||||
|
got := &api.SetupKey{}
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
assertKeys(t, got, tc.expectedSetupKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tc.expectedSetupKeys) > 0 {
|
||||||
|
var got []*api.SetupKey
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
assertKeys(t, got[0], tc.expectedSetupKeys[0])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
|
||||||
|
// this comparison is done manually because when converting to JSON dates formatted differently
|
||||||
|
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
|
||||||
|
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
|
||||||
|
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
|
||||||
|
assert.Equal(t, got.Name, expected.Name)
|
||||||
|
assert.Equal(t, got.Id, expected.Id)
|
||||||
|
assert.Equal(t, got.Key, expected.Key)
|
||||||
|
assert.Equal(t, got.Type, expected.Type)
|
||||||
|
assert.Equal(t, got.UsedTimes, expected.UsedTimes)
|
||||||
|
assert.Equal(t, got.Revoked, expected.Revoked)
|
||||||
|
assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups)
|
||||||
|
}
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -24,6 +29,59 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateUser is a PUT requests to update User data
|
||||||
|
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPut {
|
||||||
|
http.Error(w, "", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
userID := vars["id"]
|
||||||
|
if len(userID) == 0 {
|
||||||
|
http.Error(w, "invalid user ID", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.PutApiUsersIdJSONRequestBody{}
|
||||||
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userRole := server.StrRoleToUserRole(req.Role)
|
||||||
|
if userRole == server.UserRoleUnknown {
|
||||||
|
http.Error(w, "invalid user role", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newUser, err := h.accountManager.SaveUser(account.Id, &server.User{
|
||||||
|
Id: userID,
|
||||||
|
Role: userRole,
|
||||||
|
AutoGroups: req.AutoGroups,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); ok {
|
||||||
|
switch e.Code() {
|
||||||
|
case codes.NotFound:
|
||||||
|
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
|
||||||
|
default:
|
||||||
|
http.Error(w, "failed to update user", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSONObject(w, toUserResponse(newUser))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// GetUsers returns a list of users of the account this user belongs to.
|
// GetUsers returns a list of users of the account this user belongs to.
|
||||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -34,9 +92,11 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := h.accountManager.GetUsersFromAccount(account.Id)
|
data, err := h.accountManager.GetUsers(account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@@ -52,10 +112,17 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toUserResponse(user *server.UserInfo) *api.User {
|
func toUserResponse(user *server.UserInfo) *api.User {
|
||||||
|
|
||||||
|
autoGroups := user.AutoGroups
|
||||||
|
if autoGroups == nil {
|
||||||
|
autoGroups = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
return &api.User{
|
return &api.User{
|
||||||
Id: user.ID,
|
Id: user.ID,
|
||||||
Name: user.Name,
|
Name: user.Name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Role: user.Role,
|
Role: user.Role,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,9 +12,8 @@ import (
|
|||||||
type MockAccountManager struct {
|
type MockAccountManager struct {
|
||||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||||
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration) (*server.SetupKey, error)
|
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
|
||||||
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
|
||||||
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
|
|
||||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||||
@@ -51,6 +50,9 @@ type MockAccountManager struct {
|
|||||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
||||||
DeleteRouteFunc func(accountID, routeID string) error
|
DeleteRouteFunc func(accountID, routeID string) error
|
||||||
ListRoutesFunc func(accountID string) ([]*route.Route, error)
|
ListRoutesFunc func(accountID string) ([]*route.Route, error)
|
||||||
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
|
||||||
|
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
|
||||||
|
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||||
@@ -58,7 +60,7 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
|
|||||||
if am.GetUsersFromAccountFunc != nil {
|
if am.GetUsersFromAccountFunc != nil {
|
||||||
return am.GetUsersFromAccountFunc(accountID)
|
return am.GetUsersFromAccountFunc(accountID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetUsers is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||||
@@ -82,40 +84,18 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
|
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||||
func (am *MockAccountManager) AddSetupKey(
|
func (am *MockAccountManager) CreateSetupKey(
|
||||||
accountId string,
|
accountId string,
|
||||||
keyName string,
|
keyName string,
|
||||||
keyType server.SetupKeyType,
|
keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration,
|
expiresIn time.Duration,
|
||||||
|
autoGroups []string,
|
||||||
) (*server.SetupKey, error) {
|
) (*server.SetupKey, error) {
|
||||||
if am.AddSetupKeyFunc != nil {
|
if am.CreateSetupKeyFunc != nil {
|
||||||
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
|
return am.CreateSetupKeyFunc(accountId, keyName, keyType, expiresIn, autoGroups)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
|
||||||
|
|
||||||
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) RevokeSetupKey(
|
|
||||||
accountId string,
|
|
||||||
keyId string,
|
|
||||||
) (*server.SetupKey, error) {
|
|
||||||
if am.RevokeSetupKeyFunc != nil {
|
|
||||||
return am.RevokeSetupKeyFunc(accountId, keyId)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) RenameSetupKey(
|
|
||||||
accountId string,
|
|
||||||
keyId string,
|
|
||||||
newName string,
|
|
||||||
) (*server.SetupKey, error) {
|
|
||||||
if am.RenameSetupKeyFunc != nil {
|
|
||||||
return am.RenameSetupKeyFunc(accountId, keyId, newName)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
|
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
|
||||||
@@ -415,3 +395,38 @@ func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, erro
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
|
||||||
|
if am.SaveSetupKeyFunc != nil {
|
||||||
|
return am.SaveSetupKeyFunc(accountID, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupKey mocks GetSetupKey of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
|
||||||
|
if am.GetSetupKeyFunc != nil {
|
||||||
|
return am.GetSetupKeyFunc(accountID, keyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
|
||||||
|
if am.ListSetupKeysFunc != nil {
|
||||||
|
return am.ListSetupKeysFunc(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveUser mocks SaveUser of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) SaveUser(accountID string, user *server.User) (*server.UserInfo, error) {
|
||||||
|
if am.SaveUserFunc != nil {
|
||||||
|
return am.SaveUserFunc(accountID, user)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ type PeerStatus struct {
|
|||||||
Connected bool
|
Connected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerInfo is a composition of Peer and additional UserInfo
|
||||||
|
type PeerInfo struct {
|
||||||
|
Peer *Peer
|
||||||
|
UserInfo *UserInfo
|
||||||
|
}
|
||||||
|
|
||||||
// Peer represents a machine connected to the network.
|
// Peer represents a machine connected to the network.
|
||||||
// The Peer is a Wireguard peer identified by a public key
|
// The Peer is a Wireguard peer identified by a public key
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
@@ -68,6 +74,44 @@ func (p *Peer) Copy() *Peer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeers returns a list of Peers belonging to the specified account
|
||||||
|
func (am *DefaultAccountManager) GetPeers(accountID string) ([]*PeerInfo, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := am.getUsersInfos(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var userMap = make(map[string]*UserInfo)
|
||||||
|
for _, user := range users {
|
||||||
|
userMap[user.ID] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
var peerInfos []*PeerInfo
|
||||||
|
for _, peer := range account.Peers {
|
||||||
|
if peer.UserID == "" {
|
||||||
|
peerInfos = append(peerInfos, &PeerInfo{
|
||||||
|
Peer: peer,
|
||||||
|
UserInfo: nil,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
peerInfos = append(peerInfos, &PeerInfo{
|
||||||
|
Peer: peer.Copy(),
|
||||||
|
UserInfo: userMap[peer.UserID],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerInfos, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeer returns a peer from a Store
|
// GetPeer returns a peer from a Store
|
||||||
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
@@ -294,6 +338,8 @@ func (am *DefaultAccountManager) AddPeer(
|
|||||||
var account *Account
|
var account *Account
|
||||||
var err error
|
var err error
|
||||||
var sk *SetupKey
|
var sk *SetupKey
|
||||||
|
// auto-assign groups that are coming with a SetupKey or a User
|
||||||
|
var groupsToAdd []string
|
||||||
if len(upperKey) != 0 {
|
if len(upperKey) != 0 {
|
||||||
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -321,11 +367,20 @@ func (am *DefaultAccountManager) AddPeer(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupsToAdd = sk.AutoGroups
|
||||||
|
|
||||||
} else if len(userID) != 0 {
|
} else if len(userID) != 0 {
|
||||||
account, err = am.Store.GetUserAccount(userID)
|
account, err = am.Store.GetUserAccount(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
||||||
}
|
}
|
||||||
|
user, ok := account.Users[userID]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsToAdd = user.AutoGroups
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Empty setup key and jwt fail
|
// Empty setup key and jwt fail
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
|
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
|
||||||
@@ -361,6 +416,14 @@ func (am *DefaultAccountManager) AddPeer(
|
|||||||
}
|
}
|
||||||
group.Peers = append(group.Peers, newPeer.Key)
|
group.Peers = append(group.Peers, newPeer.Key)
|
||||||
|
|
||||||
|
if len(groupsToAdd) > 0 {
|
||||||
|
for _, s := range groupsToAdd {
|
||||||
|
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
||||||
|
g.Peers = append(g.Peers, newPeer.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
account.Peers[newPeer.Key] = newPeer
|
account.Peers[newPeer.Key] = newPeer
|
||||||
if len(upperKey) != 0 {
|
if len(upperKey) != 0 {
|
||||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -18,8 +21,41 @@ const (
|
|||||||
DefaultSetupKeyDuration = 24 * 30 * time.Hour
|
DefaultSetupKeyDuration = 24 * 30 * time.Hour
|
||||||
// DefaultSetupKeyName is a default name of the default setup key
|
// DefaultSetupKeyName is a default name of the default setup key
|
||||||
DefaultSetupKeyName = "Default key"
|
DefaultSetupKeyName = "Default key"
|
||||||
|
|
||||||
|
// UpdateSetupKeyName indicates a setup key name update operation
|
||||||
|
UpdateSetupKeyName SetupKeyUpdateOperationType = iota
|
||||||
|
// UpdateSetupKeyRevoked indicates a setup key revoked filed update operation
|
||||||
|
UpdateSetupKeyRevoked
|
||||||
|
// UpdateSetupKeyAutoGroups indicates a setup key auto-assign groups update operation
|
||||||
|
UpdateSetupKeyAutoGroups
|
||||||
|
// UpdateSetupKeyExpiresAt indicates a setup key expiration time update operation
|
||||||
|
UpdateSetupKeyExpiresAt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetupKeyUpdateOperationType operation type
|
||||||
|
type SetupKeyUpdateOperationType int
|
||||||
|
|
||||||
|
func (t SetupKeyUpdateOperationType) String() string {
|
||||||
|
switch t {
|
||||||
|
case UpdateSetupKeyName:
|
||||||
|
return "UpdateSetupKeyName"
|
||||||
|
case UpdateSetupKeyRevoked:
|
||||||
|
return "UpdateSetupKeyRevoked"
|
||||||
|
case UpdateSetupKeyAutoGroups:
|
||||||
|
return "UpdateSetupKeyAutoGroups"
|
||||||
|
case UpdateSetupKeyExpiresAt:
|
||||||
|
return "UpdateSetupKeyExpiresAt"
|
||||||
|
default:
|
||||||
|
return "InvalidOperation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupKeyUpdateOperation operation object with type and values to be applied
|
||||||
|
type SetupKeyUpdateOperation struct {
|
||||||
|
Type SetupKeyUpdateOperationType
|
||||||
|
Values []string
|
||||||
|
}
|
||||||
|
|
||||||
// SetupKeyType is the type of setup key
|
// SetupKeyType is the type of setup key
|
||||||
type SetupKeyType string
|
type SetupKeyType string
|
||||||
|
|
||||||
@@ -31,30 +67,40 @@ type SetupKey struct {
|
|||||||
Type SetupKeyType
|
Type SetupKeyType
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
|
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
|
||||||
Revoked bool
|
Revoked bool
|
||||||
// UsedTimes indicates how many times the key was used
|
// UsedTimes indicates how many times the key was used
|
||||||
UsedTimes int
|
UsedTimes int
|
||||||
// LastUsed last time the key was used for peer registration
|
// LastUsed last time the key was used for peer registration
|
||||||
LastUsed time.Time
|
LastUsed time.Time
|
||||||
|
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
|
||||||
|
AutoGroups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
//Copy copies SetupKey to a new object
|
// Copy copies SetupKey to a new object
|
||||||
func (key *SetupKey) Copy() *SetupKey {
|
func (key *SetupKey) Copy() *SetupKey {
|
||||||
|
autoGroups := make([]string, 0)
|
||||||
|
autoGroups = append(autoGroups, key.AutoGroups...)
|
||||||
|
if key.UpdatedAt.IsZero() {
|
||||||
|
key.UpdatedAt = key.CreatedAt
|
||||||
|
}
|
||||||
return &SetupKey{
|
return &SetupKey{
|
||||||
Id: key.Id,
|
Id: key.Id,
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Name: key.Name,
|
Name: key.Name,
|
||||||
Type: key.Type,
|
Type: key.Type,
|
||||||
CreatedAt: key.CreatedAt,
|
CreatedAt: key.CreatedAt,
|
||||||
ExpiresAt: key.ExpiresAt,
|
ExpiresAt: key.ExpiresAt,
|
||||||
Revoked: key.Revoked,
|
UpdatedAt: key.UpdatedAt,
|
||||||
UsedTimes: key.UsedTimes,
|
Revoked: key.Revoked,
|
||||||
LastUsed: key.LastUsed,
|
UsedTimes: key.UsedTimes,
|
||||||
|
LastUsed: key.LastUsed,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||||
func (key *SetupKey) IncrementUsage() *SetupKey {
|
func (key *SetupKey) IncrementUsage() *SetupKey {
|
||||||
c := key.Copy()
|
c := key.Copy()
|
||||||
c.UsedTimes = c.UsedTimes + 1
|
c.UsedTimes = c.UsedTimes + 1
|
||||||
@@ -83,24 +129,25 @@ func (key *SetupKey) IsOverUsed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSetupKey generates a new setup key
|
// GenerateSetupKey generates a new setup key
|
||||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey {
|
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string) *SetupKey {
|
||||||
key := strings.ToUpper(uuid.New().String())
|
key := strings.ToUpper(uuid.New().String())
|
||||||
createdAt := time.Now()
|
|
||||||
return &SetupKey{
|
return &SetupKey{
|
||||||
Id: strconv.Itoa(int(Hash(key))),
|
Id: strconv.Itoa(int(Hash(key))),
|
||||||
Key: key,
|
Key: key,
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: t,
|
Type: t,
|
||||||
CreatedAt: createdAt,
|
CreatedAt: time.Now(),
|
||||||
ExpiresAt: createdAt.Add(validFor),
|
ExpiresAt: time.Now().Add(validFor),
|
||||||
Revoked: false,
|
UpdatedAt: time.Now(),
|
||||||
UsedTimes: 0,
|
Revoked: false,
|
||||||
|
UsedTimes: 0,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateDefaultSetupKey generates a default setup key
|
// GenerateDefaultSetupKey generates a default setup key
|
||||||
func GenerateDefaultSetupKey() *SetupKey {
|
func GenerateDefaultSetupKey() *SetupKey {
|
||||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration)
|
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Hash(s string) uint32 {
|
func Hash(s string) uint32 {
|
||||||
@@ -111,3 +158,127 @@ func Hash(s string) uint32 {
|
|||||||
}
|
}
|
||||||
return h.Sum32()
|
return h.Sum32()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
|
||||||
|
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||||
|
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
|
||||||
|
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
keyDuration := DefaultSetupKeyDuration
|
||||||
|
if expiresIn != 0 {
|
||||||
|
keyDuration = expiresIn
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range autoGroups {
|
||||||
|
if _, ok := account.Groups[group]; !ok {
|
||||||
|
return nil, fmt.Errorf("group %s doesn't exist", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups)
|
||||||
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||||
|
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||||
|
// (e.g. the key itself, creation date, ID, etc).
|
||||||
|
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||||
|
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
if keyToSave == nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldKey *SetupKey
|
||||||
|
for _, key := range account.SetupKeys {
|
||||||
|
if key.Id == keyToSave.Id {
|
||||||
|
oldKey = key.Copy()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldKey == nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
|
newKey := oldKey.Copy()
|
||||||
|
newKey.Name = keyToSave.Name
|
||||||
|
newKey.AutoGroups = keyToSave.AutoGroups
|
||||||
|
newKey.Revoked = keyToSave.Revoked
|
||||||
|
newKey.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
account.SetupKeys[newKey.Key] = newKey
|
||||||
|
|
||||||
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newKey, am.updateAccountPeers(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSetupKeys returns a list of all setup keys of the account
|
||||||
|
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]*SetupKey, 0, len(account.SetupKeys))
|
||||||
|
for _, key := range account.SetupKeys {
|
||||||
|
keys = append(keys, key.Copy())
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||||
|
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var foundKey *SetupKey
|
||||||
|
for _, key := range account.SetupKeys {
|
||||||
|
if key.Id == keyID {
|
||||||
|
foundKey = key.Copy()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundKey == nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||||
|
if foundKey.UpdatedAt.IsZero() {
|
||||||
|
foundKey.UpdatedAt = foundKey.CreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
return foundKey, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,23 +2,159 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := "test_user"
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SaveGroup(account.Id, &Group{
|
||||||
|
ID: "group_1",
|
||||||
|
Name: "group_name_1",
|
||||||
|
Peers: []string{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresIn := time.Hour
|
||||||
|
keyName := "my-test-key"
|
||||||
|
|
||||||
|
key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
autoGroups := []string{"group_1", "group_2"}
|
||||||
|
newKeyName := "my-new-test-key"
|
||||||
|
revoked := true
|
||||||
|
newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{
|
||||||
|
Id: key.Id,
|
||||||
|
Name: newKeyName,
|
||||||
|
Revoked: revoked,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||||
|
key.Id, time.Now(), autoGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := "test_user"
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SaveGroup(account.Id, &Group{
|
||||||
|
ID: "group_1",
|
||||||
|
Name: "group_name_1",
|
||||||
|
Peers: []string{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SaveGroup(account.Id, &Group{
|
||||||
|
ID: "group_2",
|
||||||
|
Name: "group_name_2",
|
||||||
|
Peers: []string{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
expectedKeyName string
|
||||||
|
expectedUsedTimes int
|
||||||
|
expectedType string
|
||||||
|
expectedGroups []string
|
||||||
|
expectedCreatedAt time.Time
|
||||||
|
expectedUpdatedAt time.Time
|
||||||
|
expectedExpiresAt time.Time
|
||||||
|
expectedFailure bool //indicates whether key creation should fail
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expiresIn := time.Hour
|
||||||
|
testCase1 := testCase{
|
||||||
|
name: "Should Create Setup Key successfully",
|
||||||
|
expectedKeyName: "my-test-key",
|
||||||
|
expectedUsedTimes: 0,
|
||||||
|
expectedType: "reusable",
|
||||||
|
expectedGroups: []string{"group_1", "group_2"},
|
||||||
|
expectedCreatedAt: now,
|
||||||
|
expectedUpdatedAt: now,
|
||||||
|
expectedExpiresAt: now.Add(expiresIn),
|
||||||
|
expectedFailure: false,
|
||||||
|
}
|
||||||
|
testCase2 := testCase{
|
||||||
|
name: "Create Setup Key should fail because of unexistent group",
|
||||||
|
expectedKeyName: "my-test-key",
|
||||||
|
expectedGroups: []string{"FAKE"},
|
||||||
|
expectedFailure: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||||
|
t.Run(tCase.name, func(t *testing.T) {
|
||||||
|
key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||||
|
tCase.expectedGroups)
|
||||||
|
|
||||||
|
if tCase.expectedFailure {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected to fail")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
|
||||||
|
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
|
||||||
|
tCase.expectedUpdatedAt, tCase.expectedGroups)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateDefaultSetupKey(t *testing.T) {
|
func TestGenerateDefaultSetupKey(t *testing.T) {
|
||||||
expectedName := "Default key"
|
expectedName := "Default key"
|
||||||
expectedRevoke := false
|
expectedRevoke := false
|
||||||
expectedType := "reusable"
|
expectedType := "reusable"
|
||||||
expectedUsedTimes := 0
|
expectedUsedTimes := 0
|
||||||
expectedCreatedAt := time.Now()
|
expectedCreatedAt := time.Now()
|
||||||
|
expectedUpdatedAt := time.Now()
|
||||||
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
|
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
|
||||||
|
var expectedAutoGroups []string
|
||||||
|
|
||||||
key := GenerateDefaultSetupKey()
|
key := GenerateDefaultSetupKey()
|
||||||
|
|
||||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||||
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
|
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,41 +165,44 @@ func TestGenerateSetupKey(t *testing.T) {
|
|||||||
expectedUsedTimes := 0
|
expectedUsedTimes := 0
|
||||||
expectedCreatedAt := time.Now()
|
expectedCreatedAt := time.Now()
|
||||||
expectedExpiresAt := time.Now().Add(time.Hour)
|
expectedExpiresAt := time.Now().Add(time.Hour)
|
||||||
|
expectedUpdatedAt := time.Now()
|
||||||
|
var expectedAutoGroups []string
|
||||||
|
|
||||||
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour)
|
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{})
|
||||||
|
|
||||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
|
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||||
|
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKey_IsValid(t *testing.T) {
|
func TestSetupKey_IsValid(t *testing.T) {
|
||||||
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour)
|
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{})
|
||||||
if !validKey.IsValid() {
|
if !validKey.IsValid() {
|
||||||
t.Errorf("expected key to be valid, got invalid %v", validKey)
|
t.Errorf("expected key to be valid, got invalid %v", validKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// expired
|
// expired
|
||||||
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour)
|
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{})
|
||||||
if expiredKey.IsValid() {
|
if expiredKey.IsValid() {
|
||||||
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
|
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// revoked
|
// revoked
|
||||||
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
|
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
|
||||||
revokedKey.Revoked = true
|
revokedKey.Revoked = true
|
||||||
if revokedKey.IsValid() {
|
if revokedKey.IsValid() {
|
||||||
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
|
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// overused
|
// overused
|
||||||
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
|
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
|
||||||
overUsedKey.UsedTimes = 1
|
overUsedKey.UsedTimes = 1
|
||||||
if overUsedKey.IsValid() {
|
if overUsedKey.IsValid() {
|
||||||
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
|
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// overused
|
// overused
|
||||||
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour)
|
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{})
|
||||||
reusableKey.UsedTimes = 99
|
reusableKey.UsedTimes = 99
|
||||||
if !reusableKey.IsValid() {
|
if !reusableKey.IsValid() {
|
||||||
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
|
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
|
||||||
@@ -71,7 +210,8 @@ func TestSetupKey_IsValid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
||||||
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) {
|
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
|
||||||
|
expectedUpdatedAt time.Time, expectedAutoGroups []string) {
|
||||||
if key.Name != expectedName {
|
if key.Name != expectedName {
|
||||||
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
|
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
|
||||||
}
|
}
|
||||||
@@ -92,6 +232,10 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
|||||||
t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt)
|
t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 {
|
||||||
|
t.Errorf("expected setup key to have UpdatedAt ~ %v, got %v", expectedUpdatedAt, key.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
if key.CreatedAt.Sub(expectedCreatedAt).Round(time.Hour) != 0 {
|
if key.CreatedAt.Sub(expectedCreatedAt).Round(time.Hour) != 0 {
|
||||||
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
|
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
|
||||||
}
|
}
|
||||||
@@ -104,13 +248,19 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
|
|||||||
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
|
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
|
||||||
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
|
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(key.AutoGroups) != len(expectedAutoGroups) {
|
||||||
|
t.Errorf("expected key AutoGroups size=%d, got %d", len(expectedAutoGroups), len(key.AutoGroups))
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKey_Copy(t *testing.T) {
|
func TestSetupKey_Copy(t *testing.T) {
|
||||||
|
|
||||||
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour)
|
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{})
|
||||||
keyCopy := key.Copy()
|
keyCopy := key.Copy()
|
||||||
|
|
||||||
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id)
|
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
|
||||||
|
key.UpdatedAt, key.AutoGroups)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,19 +2,32 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UserRoleAdmin UserRole = "admin"
|
UserRoleAdmin UserRole = "admin"
|
||||||
UserRoleUser UserRole = "user"
|
UserRoleUser UserRole = "user"
|
||||||
|
UserRoleUnknown UserRole = "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||||
|
func StrRoleToUserRole(strRole string) UserRole {
|
||||||
|
switch strings.ToLower(strRole) {
|
||||||
|
case "admin":
|
||||||
|
return UserRoleAdmin
|
||||||
|
case "user":
|
||||||
|
return UserRoleUser
|
||||||
|
default:
|
||||||
|
return UserRoleUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// UserRole is the role of the User
|
// UserRole is the role of the User
|
||||||
type UserRole string
|
type UserRole string
|
||||||
|
|
||||||
@@ -22,20 +35,56 @@ type UserRole string
|
|||||||
type User struct {
|
type User struct {
|
||||||
Id string
|
Id string
|
||||||
Role UserRole
|
Role UserRole
|
||||||
|
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||||
|
AutoGroups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toUserInfo converts a User object to a UserInfo object.
|
||||||
|
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||||
|
autoGroups := u.AutoGroups
|
||||||
|
if autoGroups == nil {
|
||||||
|
autoGroups = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if userData == nil {
|
||||||
|
return &UserInfo{
|
||||||
|
ID: u.Id,
|
||||||
|
Email: "",
|
||||||
|
Name: "",
|
||||||
|
Role: string(u.Role),
|
||||||
|
AutoGroups: u.AutoGroups,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if userData.ID != u.Id {
|
||||||
|
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserInfo{
|
||||||
|
ID: u.Id,
|
||||||
|
Email: userData.Email,
|
||||||
|
Name: userData.Name,
|
||||||
|
Role: string(u.Role),
|
||||||
|
AutoGroups: autoGroups,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the user
|
||||||
func (u *User) Copy() *User {
|
func (u *User) Copy() *User {
|
||||||
|
autoGroups := []string{}
|
||||||
|
autoGroups = append(autoGroups, u.AutoGroups...)
|
||||||
return &User{
|
return &User{
|
||||||
Id: u.Id,
|
Id: u.Id,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUser creates a new user
|
// NewUser creates a new user
|
||||||
func NewUser(id string, role UserRole) *User {
|
func NewUser(id string, role UserRole) *User {
|
||||||
return &User{
|
return &User{
|
||||||
Id: id,
|
Id: id,
|
||||||
Role: role,
|
Role: role,
|
||||||
|
AutoGroups: []string{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +98,55 @@ func NewAdminUser(id string) *User {
|
|||||||
return NewUser(id, UserRoleAdmin)
|
return NewUser(id, UserRoleAdmin)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||||
|
// Only User.AutoGroups field is allowed to be updated for now.
|
||||||
|
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
if update == nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newGroupID := range update.AutoGroups {
|
||||||
|
if _, ok := account.Groups[newGroupID]; !ok {
|
||||||
|
return nil,
|
||||||
|
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||||
|
newGroupID, update.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
oldUser := account.Users[update.Id]
|
||||||
|
if oldUser == nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "update not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
|
newUser := oldUser.Copy()
|
||||||
|
newUser.AutoGroups = update.AutoGroups
|
||||||
|
newUser.Role = update.Role
|
||||||
|
|
||||||
|
account.Users[newUser.Id] = newUser
|
||||||
|
|
||||||
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNil(am.idpManager) {
|
||||||
|
userData, err := am.lookupUserInCache(newUser, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newUser.toUserInfo(userData)
|
||||||
|
}
|
||||||
|
return newUser.toUserInfo(nil)
|
||||||
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
@@ -108,3 +206,50 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
|
|||||||
|
|
||||||
return user.Role == UserRoleAdmin, nil
|
return user.Role == UserRoleAdmin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getUsersInfos(account *Account) ([]*UserInfo, error) {
|
||||||
|
var err error
|
||||||
|
queriedUsers := make([]*idp.UserData, 0)
|
||||||
|
if !isNil(am.idpManager) {
|
||||||
|
queriedUsers, err = am.lookupCache(account.Users, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userInfos := make([]*UserInfo, 0)
|
||||||
|
|
||||||
|
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||||
|
if len(queriedUsers) == 0 {
|
||||||
|
for _, user := range account.Users {
|
||||||
|
info, err := user.toUserInfo(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userInfos = append(userInfos, info)
|
||||||
|
}
|
||||||
|
return userInfos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, queriedUser := range queriedUsers {
|
||||||
|
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
||||||
|
|
||||||
|
info, err := localUser.toUserInfo(queriedUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userInfos = append(userInfos, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsers performs a batched request for users from IDP by account ID
|
||||||
|
func (am *DefaultAccountManager) GetUsers(accountID string) ([]*UserInfo, error) {
|
||||||
|
account, err := am.GetAccountById(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return am.getUsersInfos(account)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user