mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 09:46:40 +00:00
Compare commits
25 Commits
v0.18.0
...
separate_p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b43d7e8ef | ||
|
|
dcc83c8741 | ||
|
|
d56669ec2e | ||
|
|
e3d2b6a408 | ||
|
|
9f758b2015 | ||
|
|
2c50d7af1e | ||
|
|
e4c28f64fa | ||
|
|
6f2c4078ef | ||
|
|
f4ec1699ca | ||
|
|
fea53b2f0f | ||
|
|
60e6d0890a | ||
|
|
cb12e2da21 | ||
|
|
873b56f856 | ||
|
|
ecac82a5ae | ||
|
|
59372ee159 | ||
|
|
08db5f5a42 | ||
|
|
88678ef364 | ||
|
|
f1da4fd55d | ||
|
|
45224e76d0 | ||
|
|
90c8cfd863 | ||
|
|
f7196cd9a5 | ||
|
|
53d78ad982 | ||
|
|
9f352c1b7e | ||
|
|
a89808ecae | ||
|
|
c6190fa2ba |
9
.github/workflows/golang-test-linux.yml
vendored
9
.github/workflows/golang-test-linux.yml
vendored
@@ -72,6 +72,9 @@ jobs:
|
|||||||
- name: Generate Iface Test bin
|
- name: Generate Iface Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||||
|
|
||||||
|
- name: Generate Shared Sock Test bin
|
||||||
|
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
@@ -83,9 +86,13 @@ jobs:
|
|||||||
|
|
||||||
- run: chmod +x *testing.bin
|
- run: chmod +x *testing.bin
|
||||||
|
|
||||||
|
- name: Run Shared Sock tests in docker
|
||||||
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
- name: Run Iface tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
- name: Run RouteManager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
@@ -93,4 +100,4 @@ jobs:
|
|||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
- name: Run Peer tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
@@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
@@ -45,12 +46,16 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
PreSharedKey: &preSharedKey,
|
}
|
||||||
})
|
if preSharedKey != "" {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -106,7 +111,7 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,7 +190,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||||
@@ -199,11 +204,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||||
|
var codeMsg string
|
||||||
|
if !strings.Contains(verificationURIComplete, userCode) {
|
||||||
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
|
}
|
||||||
|
|
||||||
err := open.Run(verificationURIComplete)
|
err := open.Run(verificationURIComplete)
|
||||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
" " + verificationURIComplete + " \n\n")
|
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,14 +78,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
PreSharedKey: &preSharedKey,
|
|
||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
})
|
}
|
||||||
|
if preSharedKey != "" {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -172,7 +176,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -18,14 +19,15 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@@ -99,10 +101,8 @@ type Engine struct {
|
|||||||
|
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
|
|
||||||
udpMux ice.UDPMux
|
udpMux *bind.UniversalUDPMuxDefault
|
||||||
udpMuxSrflx ice.UniversalUDPMux
|
udpMuxConn io.Closer
|
||||||
udpMuxConn *net.UDPConn
|
|
||||||
udpMuxConnSrflx *net.UDPConn
|
|
||||||
|
|
||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
@@ -206,33 +206,17 @@ func (e *Engine) Start() error {
|
|||||||
e.close()
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
e.udpMux = udpMux.UDPMuxDefault
|
e.udpMux = udpMux
|
||||||
e.udpMuxSrflx = udpMux
|
|
||||||
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||||
} else {
|
} else {
|
||||||
networkName := "udp"
|
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter())
|
||||||
if e.config.DisableIPv6Discovery {
|
|
||||||
networkName = "udp4"
|
|
||||||
}
|
|
||||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
|
||||||
e.close()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
udpMuxParams := ice.UDPMuxParams{
|
mux := bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: rawSock, Net: transportNet})
|
||||||
UDPConn: e.udpMuxConn,
|
go mux.ReadFromConn(e.ctx)
|
||||||
Net: transportNet,
|
e.udpMuxConn = rawSock
|
||||||
}
|
e.udpMux = mux
|
||||||
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
|
||||||
|
|
||||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
|
||||||
e.close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||||
@@ -262,7 +246,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||||
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -394,9 +378,6 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// indicates message support in gRPC
|
|
||||||
msg.Body.FeaturesSupported = []uint32{signal.DirectCheck}
|
|
||||||
|
|
||||||
err = s.Send(msg)
|
err = s.Send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -776,9 +757,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
|
|
||||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
conf := conn.GetConf()
|
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||||
conf.StunTurn = append(e.STUNs, e.TURNs...)
|
|
||||||
conn.UpdateConf(conf)
|
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := conn.Open()
|
err := conn.Open()
|
||||||
@@ -807,9 +786,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
stunTurn = append(stunTurn, e.STUNs...)
|
stunTurn = append(stunTurn, e.STUNs...)
|
||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
|
|
||||||
proxyConfig := proxy.Config{
|
wgConfig := peer.WgConfig{
|
||||||
RemoteKey: pubKey,
|
RemoteKey: pubKey,
|
||||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
WgListenPort: e.config.WgPort,
|
||||||
WgInterface: e.wgInterface,
|
WgInterface: e.wgInterface,
|
||||||
AllowedIps: allowedIPs,
|
AllowedIps: allowedIPs,
|
||||||
PreSharedKey: e.config.PreSharedKey,
|
PreSharedKey: e.config.PreSharedKey,
|
||||||
@@ -824,9 +803,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
UDPMux: e.udpMux,
|
UDPMux: e.udpMux.UDPMuxDefault,
|
||||||
UDPMuxSrflx: e.udpMuxSrflx,
|
UDPMuxSrflx: e.udpMux,
|
||||||
ProxyConfig: proxyConfig,
|
WgConfig: wgConfig,
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||||
@@ -918,18 +897,6 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
}
|
}
|
||||||
conn.OnRemoteCandidate(candidate)
|
conn.OnRemoteCandidate(candidate)
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
protoMode := msg.GetBody().GetMode()
|
|
||||||
if protoMode == nil {
|
|
||||||
return fmt.Errorf("received an empty mode message")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := conn.OnModeMessage(peer.ModeMessage{
|
|
||||||
Direct: protoMode.GetDirect(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed processing a mode message -> %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1020,12 +987,6 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.udpMuxConnSrflx != nil {
|
|
||||||
if err := e.udpMuxConnSrflx.Close(); err != nil {
|
|
||||||
log.Debugf("close server reflexive udp mux connection: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isNil(e.sshServer) {
|
if !isNil(e.sshServer) {
|
||||||
err := e.sshServer.Stop()
|
err := e.sshServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -148,6 +148,11 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|||||||
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||||
|
if deviceCode.VerificationURIComplete == "" {
|
||||||
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
return deviceCode, err
|
return deviceCode, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,10 +12,11 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
@@ -27,8 +28,18 @@ const (
|
|||||||
|
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
|
||||||
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WgConfig struct {
|
||||||
|
WgListenPort int
|
||||||
|
RemoteKey string
|
||||||
|
WgInterface *iface.WGIface
|
||||||
|
AllowedIps string
|
||||||
|
PreSharedKey *wgtypes.Key
|
||||||
|
}
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
|
|
||||||
@@ -47,7 +58,7 @@ type ConnConfig struct {
|
|||||||
|
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
ProxyConfig proxy.Config
|
WgConfig WgConfig
|
||||||
|
|
||||||
UDPMux ice.UDPMux
|
UDPMux ice.UDPMux
|
||||||
UDPMuxSrflx ice.UniversalUDPMux
|
UDPMuxSrflx ice.UniversalUDPMux
|
||||||
@@ -102,7 +113,7 @@ type Conn struct {
|
|||||||
|
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
|
|
||||||
proxy proxy.Proxy
|
proxy *WireGuardProxy
|
||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
@@ -126,9 +137,14 @@ func (conn *Conn) GetConf() ConnConfig {
|
|||||||
return conn.config
|
return conn.config
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConf updates the connection config
|
// WgConfig returns the WireGuard config
|
||||||
func (conn *Conn) UpdateConf(conf ConnConfig) {
|
func (conn *Conn) WgConfig() WgConfig {
|
||||||
conn.config = conf
|
return conn.config.WgConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStunTurn update the turn and stun addresses
|
||||||
|
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
||||||
|
conn.config.StunTurn = turnStun
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -239,12 +255,12 @@ func readICEAgentConfigProperties() (time.Duration, time.Duration) {
|
|||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open() error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -299,10 +315,11 @@ func (conn *Conn) Open() error {
|
|||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
peerState = State{PubKey: conn.config.Key}
|
peerState = State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -333,19 +350,12 @@ func (conn *Conn) Open() error {
|
|||||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||||
}
|
}
|
||||||
// the ice connection has been established successfully so we are ready to start the proxy
|
// the ice connection has been established successfully so we are ready to start the proxy
|
||||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
|
||||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
|
||||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
|
||||||
// direct Wireguard connection
|
|
||||||
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||||
select {
|
select {
|
||||||
@@ -358,182 +368,81 @@ func (conn *Conn) Open() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// useProxy determines whether a direct connection (without a go proxy) is possible
|
|
||||||
//
|
|
||||||
// There are 3 cases:
|
|
||||||
//
|
|
||||||
// * When neither candidate is from hard nat and one of the peers has a public IP
|
|
||||||
//
|
|
||||||
// * both peers are in the same private network
|
|
||||||
//
|
|
||||||
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
|
|
||||||
//
|
|
||||||
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
|
||||||
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
|
|
||||||
|
|
||||||
if !isRelayCandidate(pair.Local) && userspaceBind {
|
|
||||||
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
|
|
||||||
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
|
|
||||||
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
|
|
||||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
|
|
||||||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
|
|
||||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
|
|
||||||
|
|
||||||
localIP := net.ParseIP(pair.Local.Address())
|
|
||||||
remoteIP := net.ParseIP(pair.Remote.Address())
|
|
||||||
if localIP == nil || remoteIP == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// only consider /16 networks
|
|
||||||
mask := net.IPMask{255, 255, 0, 0}
|
|
||||||
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
func isHardNATCandidate(candidate ice.Candidate) bool {
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
|
||||||
}
|
|
||||||
|
|
||||||
func isHostCandidateWithPublicIP(candidate ice.Candidate) bool {
|
|
||||||
return candidate.Type() == ice.CandidateTypeHost && isPublicIP(candidate.Address())
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
|
||||||
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
|
||||||
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPublicIP(address string) bool {
|
|
||||||
ip := net.ParseIP(address)
|
|
||||||
if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
|
||||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
var pair *ice.CandidatePair
|
|
||||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
var endpoint net.Addr
|
||||||
p := conn.getProxyWithMessageExchange(pair, remoteWgPort)
|
if isRelayCandidate(pair.Local) {
|
||||||
conn.proxy = p
|
conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
|
||||||
err = p.Start(remoteConn)
|
endpoint, err = conn.proxy.Start()
|
||||||
|
if err != nil {
|
||||||
|
conn.proxy = nil
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||||
|
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||||
|
endpoint = remoteConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if conn.proxy != nil {
|
||||||
|
_ = conn.proxy.Close()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status
|
peerState := State{
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
PubKey: conn.config.Key,
|
||||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
ConnStatus: conn.status,
|
||||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
|
Direct: conn.proxy == nil,
|
||||||
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
}
|
}
|
||||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
|
||||||
|
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
|
// wait local endpoint configuration
|
||||||
localDirectMode := !useProxy
|
time.Sleep(time.Second)
|
||||||
remoteDirectMode := localDirectMode
|
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
|
||||||
|
|
||||||
if conn.meta.protoSupport.DirectCheck {
|
|
||||||
go conn.sendLocalDirectMode(localDirectMode)
|
|
||||||
// will block until message received or timeout
|
|
||||||
remoteDirectMode = conn.receiveRemoteDirectMode()
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.config.UserspaceBind && localDirectMode {
|
|
||||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
if localDirectMode && remoteDirectMode {
|
|
||||||
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
|
||||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) sendLocalDirectMode(localMode bool) {
|
|
||||||
// todo what happens when we couldn't deliver this message?
|
|
||||||
// we could retry, etc but there is no guarantee
|
|
||||||
|
|
||||||
err := conn.sendSignalMessage(&sProto.Message{
|
|
||||||
Key: conn.config.LocalKey,
|
|
||||||
RemoteKey: conn.config.Key,
|
|
||||||
Body: &sProto.Body{
|
|
||||||
Type: sProto.Body_MODE,
|
|
||||||
Mode: &sProto.Mode{
|
|
||||||
Direct: &localMode,
|
|
||||||
},
|
|
||||||
NetBirdVersion: version.NetbirdVersion(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to send local proxy mode to remote peer %s, error: %s", conn.config.Key, err)
|
log.Warnf("got an error while resolving the udp address, err: %s", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) receiveRemoteDirectMode() bool {
|
mux, ok := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
|
||||||
timeout := time.Second
|
if !ok {
|
||||||
timer := time.NewTimer(timeout)
|
log.Warn("invalid udp mux conversion")
|
||||||
defer timer.Stop()
|
return
|
||||||
|
}
|
||||||
select {
|
_, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
|
||||||
case receivedMSG := <-conn.remoteModeCh:
|
if err != nil {
|
||||||
return receivedMSG.Direct
|
log.Warnf("got an error while sending the punch packet, err: %s", err)
|
||||||
case <-timer.C:
|
|
||||||
// we didn't receive a message from remote so we assume that it supports the direct mode to keep the old behaviour
|
|
||||||
log.Debugf("timeout after %s while waiting for remote direct mode message from remote peer %s",
|
|
||||||
timeout, conn.config.Key)
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -543,20 +452,22 @@ func (conn *Conn) cleanup() error {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2, err3 error
|
||||||
if conn.agent != nil {
|
if conn.agent != nil {
|
||||||
err := conn.agent.Close()
|
err1 = conn.agent.Close()
|
||||||
if err != nil {
|
if err1 == nil {
|
||||||
return err
|
conn.agent = nil
|
||||||
}
|
}
|
||||||
conn.agent = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
|
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
if conn.proxy != nil {
|
if conn.proxy != nil {
|
||||||
err := conn.proxy.Close()
|
err3 = conn.proxy.Close()
|
||||||
if err != nil {
|
if err3 != nil {
|
||||||
return err
|
conn.proxy = nil
|
||||||
}
|
}
|
||||||
conn.proxy = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
@@ -566,10 +477,11 @@ func (conn *Conn) cleanup() error {
|
|||||||
|
|
||||||
conn.status = StatusDisconnected
|
conn.status = StatusDisconnected
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
peerState.ConnStatus = conn.status
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatus: conn.status,
|
||||||
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||||
@@ -578,8 +490,13 @@ func (conn *Conn) cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||||
|
if err1 != nil {
|
||||||
return nil
|
return err1
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
return err3
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||||
@@ -757,16 +674,6 @@ func (conn *Conn) GetKey() string {
|
|||||||
return conn.config.Key
|
return conn.config.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnModeMessage unmarshall the payload message and send it to the mode message channel
|
|
||||||
func (conn *Conn) OnModeMessage(message ModeMessage) error {
|
|
||||||
select {
|
|
||||||
case conn.remoteModeCh <- message:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unable to process mode message: channel busy")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||||
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||||
protoSupport := signal.ParseFeaturesSupported(support)
|
protoSupport := signal.ParseFeaturesSupported(support)
|
||||||
|
|||||||
@@ -9,11 +9,9 @@ import (
|
|||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
sproto "github.com/netbirdio/netbird/signal/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
@@ -170,310 +168,3 @@ func TestConn_Close(t *testing.T) {
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockICECandidate struct {
|
|
||||||
ice.Candidate
|
|
||||||
AddressFunc func() string
|
|
||||||
TypeFunc func() ice.CandidateType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address mocks and overwrite ice.Candidate Address method
|
|
||||||
func (m *mockICECandidate) Address() string {
|
|
||||||
if m.AddressFunc != nil {
|
|
||||||
return m.AddressFunc()
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type mocks and overwrite ice.Candidate Type method
|
|
||||||
func (m *mockICECandidate) Type() ice.CandidateType {
|
|
||||||
if m.TypeFunc != nil {
|
|
||||||
return m.TypeFunc()
|
|
||||||
}
|
|
||||||
return ice.CandidateTypeUnspecified
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConn_ShouldUseProxy(t *testing.T) {
|
|
||||||
publicHostCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "8.8.8.8"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeHost
|
|
||||||
},
|
|
||||||
}
|
|
||||||
privateHostCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "10.0.0.1"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeHost
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
srflxCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "1.1.1.1"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeServerReflexive
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
prflxCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "1.1.1.1"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypePeerReflexive
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
relayCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "1.1.1.1"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeRelay
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
candatePair *ice.CandidatePair
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Use Proxy When Local Candidate Is Relay",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: relayCandidate,
|
|
||||||
Remote: privateHostCandidate,
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Use Proxy When Remote Candidate Is Relay",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: privateHostCandidate,
|
|
||||||
Remote: relayCandidate,
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Use Proxy When Local Candidate Is Peer Reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: prflxCandidate,
|
|
||||||
Remote: privateHostCandidate,
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Use Proxy When Remote Candidate Is Peer Reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: privateHostCandidate,
|
|
||||||
Remote: prflxCandidate,
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Local Candidate Is Public And Remote Is Private",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: privateHostCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Remote Candidate Is Public And Local Is Private",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: privateHostCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Local Candidate is Public And Remote Is Server Reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: srflxCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Remote Candidate is Public And Local Is Server Reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: srflxCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Both Candidates Are Public",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Both Candidates Are Private",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: privateHostCandidate,
|
|
||||||
Remote: privateHostCandidate,
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: &mockICECandidate{AddressFunc: func() string {
|
|
||||||
return "10.16.102.168"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeHost
|
|
||||||
}},
|
|
||||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
|
||||||
return "10.16.101.96"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypePeerReflexive
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: &mockICECandidate{AddressFunc: func() string {
|
|
||||||
return "10.16.102.168"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypePeerReflexive
|
|
||||||
}},
|
|
||||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
|
||||||
return "10.16.101.96"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypePeerReflexive
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
result := shouldUseProxy(testCase.candatePair, false)
|
|
||||||
if result != testCase.expected {
|
|
||||||
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetProxyWithMessageExchange(t *testing.T) {
|
|
||||||
publicHostCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "8.8.8.8"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeHost
|
|
||||||
},
|
|
||||||
}
|
|
||||||
relayCandidate := &mockICECandidate{
|
|
||||||
AddressFunc: func() string {
|
|
||||||
return "1.1.1.1"
|
|
||||||
},
|
|
||||||
TypeFunc: func() ice.CandidateType {
|
|
||||||
return ice.CandidateTypeRelay
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
candatePair *ice.CandidatePair
|
|
||||||
inputDirectModeSupport bool
|
|
||||||
inputRemoteModeMessage bool
|
|
||||||
expected proxy.Type
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Should Result In Using Wireguard Proxy When Local Eval Is Use Proxy",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: relayCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
inputDirectModeSupport: true,
|
|
||||||
inputRemoteModeMessage: true,
|
|
||||||
expected: proxy.TypeWireGuard,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
inputDirectModeSupport: true,
|
|
||||||
inputRemoteModeMessage: false,
|
|
||||||
expected: proxy.TypeWireGuard,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: relayCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
inputDirectModeSupport: false,
|
|
||||||
inputRemoteModeMessage: false,
|
|
||||||
expected: proxy.TypeWireGuard,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
inputDirectModeSupport: false,
|
|
||||||
inputRemoteModeMessage: false,
|
|
||||||
expected: proxy.TypeDirectNoProxy,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
|
|
||||||
candatePair: &ice.CandidatePair{
|
|
||||||
Local: publicHostCandidate,
|
|
||||||
Remote: publicHostCandidate,
|
|
||||||
},
|
|
||||||
inputDirectModeSupport: true,
|
|
||||||
inputRemoteModeMessage: true,
|
|
||||||
expected: proxy.TypeDirectNoProxy,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
g := errgroup.Group{}
|
|
||||||
conn, err := NewConn(connConf, nil, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn.meta.protoSupport.DirectCheck = testCase.inputDirectModeSupport
|
|
||||||
conn.SetSendSignalMessage(func(message *sproto.Message) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
g.Go(func() error {
|
|
||||||
return conn.OnModeMessage(ModeMessage{
|
|
||||||
Direct: testCase.inputRemoteModeMessage,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
resultProxy := conn.getProxyWithMessageExchange(testCase.candatePair, 1000)
|
|
||||||
|
|
||||||
err = g.Wait()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if resultProxy.Type() != testCase.expected {
|
|
||||||
t.Errorf("result didn't match expected value: Expected: %s, Got: %s", testCase.expected, resultProxy.Type())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package proxy
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
@@ -11,67 +12,45 @@ type WireGuardProxy struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
config Config
|
wgListenPort int
|
||||||
|
remoteKey string
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
|
||||||
p := &WireGuardProxy{config: config}
|
p := &WireGuardProxy{
|
||||||
|
wgListenPort: wgListenPort,
|
||||||
|
remoteKey: remoteKey,
|
||||||
|
remoteConn: remoteConn,
|
||||||
|
}
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) updateEndpoint() error {
|
func (p *WireGuardProxy) Start() (net.Addr, error) {
|
||||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// add local proxy connection as a Wireguard peer
|
|
||||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
udpAddr, p.config.PreSharedKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.remoteConn = remoteConn
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
err = p.updateEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
p.localConn = lConn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
go p.proxyToLocal()
|
go p.proxyToLocal()
|
||||||
|
|
||||||
return nil
|
return lConn.LocalAddr(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) Close() error {
|
func (p *WireGuardProxy) Close() error {
|
||||||
p.cancel()
|
p.cancel()
|
||||||
if c := p.localConn; c != nil {
|
if p.localConn != nil {
|
||||||
err := p.localConn.Close()
|
err := p.localConn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +62,7 @@ func (p *WireGuardProxy) proxyToRemote() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
|
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
@@ -107,7 +86,7 @@ func (p *WireGuardProxy) proxyToLocal() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
|
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
@@ -122,7 +101,3 @@ func (p *WireGuardProxy) proxyToLocal() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) Type() Type {
|
|
||||||
return TypeWireGuard
|
|
||||||
}
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
|
|
||||||
// This is possible in either of these cases:
|
|
||||||
// - peers are in the same local network
|
|
||||||
// - one of the peers has a public static IP (host)
|
|
||||||
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
|
|
||||||
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
|
|
||||||
type DirectNoProxy struct {
|
|
||||||
config Config
|
|
||||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
|
||||||
// It is used instead of the hardcoded 51820 port.
|
|
||||||
RemoteWgListenPort int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
|
|
||||||
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
|
|
||||||
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close removes peer from the WireGuard interface
|
|
||||||
func (p *DirectNoProxy) Close() error {
|
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start just updates WireGuard peer with the remote IP and default WireGuard port
|
|
||||||
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
|
|
||||||
|
|
||||||
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
addr.Port = p.RemoteWgListenPort
|
|
||||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
addr, p.config.PreSharedKey)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the type of this proxy
|
|
||||||
func (p *DirectNoProxy) Type() Type {
|
|
||||||
return TypeDirectNoProxy
|
|
||||||
}
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DummyProxy just sends pings to the RemoteKey peer and reads responses
|
|
||||||
type DummyProxy struct {
|
|
||||||
conn net.Conn
|
|
||||||
remote string
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDummyProxy(remote string) *DummyProxy {
|
|
||||||
p := &DummyProxy{remote: remote}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Close() error {
|
|
||||||
p.cancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.conn = remoteConn
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Write([]byte("hello"))
|
|
||||||
//log.Debugf("sent ping to %s", p.remote)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Type() Type {
|
|
||||||
return TypeDummy
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
|
||||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
|
||||||
type NoProxy struct {
|
|
||||||
config Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNoProxy creates a new NoProxy with a provided config
|
|
||||||
func NewNoProxy(config Config) *NoProxy {
|
|
||||||
return &NoProxy{config: config}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close removes peer from the WireGuard interface
|
|
||||||
func (p *NoProxy) Close() error {
|
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start just updates WireGuard peer with the remote address
|
|
||||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
|
||||||
|
|
||||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
addr, p.config.PreSharedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *NoProxy) Type() Type {
|
|
||||||
return TypeNoProxy
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultWgKeepAlive = 25 * time.Second
|
|
||||||
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
|
||||||
TypeWireGuard Type = "WireGuard"
|
|
||||||
TypeDummy Type = "Dummy"
|
|
||||||
TypeNoProxy Type = "NoProxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
WgListenAddr string
|
|
||||||
RemoteKey string
|
|
||||||
WgInterface *iface.WGIface
|
|
||||||
AllowedIps string
|
|
||||||
PreSharedKey *wgtypes.Key
|
|
||||||
}
|
|
||||||
|
|
||||||
type Proxy interface {
|
|
||||||
io.Closer
|
|
||||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
|
||||||
Start(remoteConn net.Conn) error
|
|
||||||
Type() Type
|
|
||||||
}
|
|
||||||
8
go.mod
8
go.mod
@@ -38,12 +38,14 @@ require (
|
|||||||
github.com/gliderlabs/ssh v0.3.4
|
github.com/gliderlabs/ssh v0.3.4
|
||||||
github.com/godbus/dbus/v5 v5.1.0
|
github.com/godbus/dbus/v5 v5.1.0
|
||||||
github.com/google/go-cmp v0.5.9
|
github.com/google/go-cmp v0.5.9
|
||||||
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
github.com/hashicorp/go-version v1.6.0
|
github.com/hashicorp/go-version v1.6.0
|
||||||
github.com/libp2p/go-netroute v0.2.0
|
github.com/libp2p/go-netroute v0.2.0
|
||||||
github.com/magiconair/properties v1.8.5
|
github.com/magiconair/properties v1.8.5
|
||||||
github.com/mattn/go-sqlite3 v1.14.16
|
github.com/mattn/go-sqlite3 v1.14.16
|
||||||
|
github.com/mdlayher/socket v0.4.0
|
||||||
github.com/miekg/dns v1.1.43
|
github.com/miekg/dns v1.1.43
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/open-policy-agent/opa v0.49.0
|
github.com/open-policy-agent/opa v0.49.0
|
||||||
@@ -93,14 +95,12 @@ require (
|
|||||||
github.com/go-stack/stack v1.8.0 // indirect
|
github.com/go-stack/stack v1.8.0 // indirect
|
||||||
github.com/gobwas/glob v0.2.3 // indirect
|
github.com/gobwas/glob v0.2.3 // indirect
|
||||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||||
github.com/google/gopacket v1.1.19 // indirect
|
|
||||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
|
github.com/josharian/native v1.0.0 // indirect
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||||
github.com/mdlayher/genetlink v1.1.0 // indirect
|
github.com/mdlayher/genetlink v1.1.0 // indirect
|
||||||
github.com/mdlayher/netlink v1.4.2 // indirect
|
github.com/mdlayher/netlink v1.7.1 // indirect
|
||||||
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb // indirect
|
|
||||||
github.com/nxadm/tail v1.4.8 // indirect
|
github.com/nxadm/tail v1.4.8 // indirect
|
||||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||||
|
|||||||
11
go.sum
11
go.sum
@@ -380,8 +380,9 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf
|
|||||||
github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc=
|
github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc=
|
||||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||||
github.com/josephspurrier/goversioninfo v0.0.0-20200309025242-14b0ab84c6ca/go.mod h1:eJTEwMjXb7kZ633hO3Ln9mBUCOjX2+FlTljvpl9SYdE=
|
github.com/josephspurrier/goversioninfo v0.0.0-20200309025242-14b0ab84c6ca/go.mod h1:eJTEwMjXb7kZ633hO3Ln9mBUCOjX2+FlTljvpl9SYdE=
|
||||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 h1:uhL5Gw7BINiiPAo24A2sxkcDI0Jt/sqp1v5xQCniEFA=
|
|
||||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||||
|
github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk=
|
||||||
|
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
||||||
@@ -391,7 +392,6 @@ github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c
|
|||||||
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
|
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
|
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
|
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 h1:N527AHMa793TP5z5GNAn/VLPzlc0ewzWdeP/25gDfgQ=
|
|
||||||
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs=
|
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs=
|
||||||
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||||
@@ -441,7 +441,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5
|
|||||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||||
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
|
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
|
||||||
github.com/mdlayher/ethtool v0.0.0-20211028163843-288d040e9d60 h1:tHdB+hQRHU10CfcK0furo6rSNgZ38JT8uPh70c/pFD8=
|
|
||||||
github.com/mdlayher/ethtool v0.0.0-20211028163843-288d040e9d60/go.mod h1:aYbhishWc4Ai3I2U4Gaa2n3kHWSwzme6EsG/46HRQbE=
|
github.com/mdlayher/ethtool v0.0.0-20211028163843-288d040e9d60/go.mod h1:aYbhishWc4Ai3I2U4Gaa2n3kHWSwzme6EsG/46HRQbE=
|
||||||
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
|
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
|
||||||
github.com/mdlayher/genetlink v1.1.0 h1:k2YQT3959rJOF7gOvhdfQ0lut7QMIZiuVlJANheoZ+E=
|
github.com/mdlayher/genetlink v1.1.0 h1:k2YQT3959rJOF7gOvhdfQ0lut7QMIZiuVlJANheoZ+E=
|
||||||
@@ -456,12 +455,14 @@ github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnN
|
|||||||
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
|
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
|
||||||
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
|
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
|
||||||
github.com/mdlayher/netlink v1.4.1/go.mod h1:e4/KuJ+s8UhfUpO9z00/fDZZmhSrs+oxyqAS9cNgn6Q=
|
github.com/mdlayher/netlink v1.4.1/go.mod h1:e4/KuJ+s8UhfUpO9z00/fDZZmhSrs+oxyqAS9cNgn6Q=
|
||||||
github.com/mdlayher/netlink v1.4.2 h1:3sbnJWe/LETovA7yRZIX3f9McVOWV3OySH6iIBxiFfI=
|
|
||||||
github.com/mdlayher/netlink v1.4.2/go.mod h1:13VaingaArGUTUxFLf/iEovKxXji32JAtF858jZYEug=
|
github.com/mdlayher/netlink v1.4.2/go.mod h1:13VaingaArGUTUxFLf/iEovKxXji32JAtF858jZYEug=
|
||||||
|
github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg=
|
||||||
|
github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ=
|
||||||
github.com/mdlayher/socket v0.0.0-20210307095302-262dc9984e00/go.mod h1:GAFlyu4/XV68LkQKYzKhIo/WW7j3Zi0YRAz/BOoanUc=
|
github.com/mdlayher/socket v0.0.0-20210307095302-262dc9984e00/go.mod h1:GAFlyu4/XV68LkQKYzKhIo/WW7j3Zi0YRAz/BOoanUc=
|
||||||
github.com/mdlayher/socket v0.0.0-20211007213009-516dcbdf0267/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
github.com/mdlayher/socket v0.0.0-20211007213009-516dcbdf0267/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
||||||
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMVzFJ00qM25lqESg9Z4u3GuEXN5iHY=
|
|
||||||
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
||||||
|
github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw=
|
||||||
|
github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc=
|
||||||
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
|
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
|
||||||
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
||||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package bind
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
@@ -68,6 +69,39 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadFromConn reads from the m.params.UDPConn provided upon the creation. It expects STUN packets only, however, will
|
||||||
|
// just ignore other packets printing an warning message.
|
||||||
|
// It is a blocking method, consider running in a go routine.
|
||||||
|
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Debugf("stopped reading from the UDPConn due to finished context")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, a, err := m.params.UDPConn.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while reading packet %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg := &stun.Message{
|
||||||
|
Raw: buf,
|
||||||
|
}
|
||||||
|
err = msg.Decode()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.HandleSTUNMessage(msg, a)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while handling STUn message: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||||
type udpConn struct {
|
type udpConn struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
@@ -75,6 +109,11 @@ type udpConn struct {
|
|||||||
logger logging.LeveledLogger
|
logger logging.LeveledLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSharedConn returns the shared udp conn
|
||||||
|
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
|
||||||
|
return m.params.UDPConn
|
||||||
|
}
|
||||||
|
|
||||||
// GetListenAddresses returns the listen addr of this UDP
|
// GetListenAddresses returns the listen addr of this UDP
|
||||||
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
|
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||||
return []net.Addr{m.LocalAddr()}
|
return []net.Addr{m.LocalAddr()}
|
||||||
|
|||||||
@@ -77,12 +77,17 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
// Endpoint is optional
|
// Endpoint is optional
|
||||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint net.Addr, preSharedKey *wgtypes.Key) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
rAddr, err := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||||
|
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, rAddr, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ func (c *wGConfigurer) configure(config wgtypes.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Debugf("got Wireguard device %s", c.deviceName)
|
log.Tracef("got Wireguard device %s", c.deviceName)
|
||||||
|
|
||||||
return wg.ConfigureDevice(c.deviceName, config)
|
return wg.ConfigureDevice(c.deviceName, config)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ var (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
|
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
|
||||||
}
|
}
|
||||||
|
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
|
||||||
|
|
||||||
tlsEnabled := false
|
tlsEnabled := false
|
||||||
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
|
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
|
||||||
@@ -186,6 +187,7 @@ var (
|
|||||||
config.HttpConfig.AuthIssuer,
|
config.HttpConfig.AuthIssuer,
|
||||||
config.GetAuthAudiences(),
|
config.GetAuthAudiences(),
|
||||||
config.HttpConfig.AuthKeysLocation,
|
config.HttpConfig.AuthKeysLocation,
|
||||||
|
config.HttpConfig.IdpSignKeyRefreshEnabled,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating JWT validator: %v", err)
|
return fmt.Errorf("failed creating JWT validator: %v", err)
|
||||||
@@ -394,6 +396,12 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
|||||||
}
|
}
|
||||||
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||||
|
|
||||||
|
log.Infof("configuring IdpManagerConfig.OIDCConfig.Issuer with a new value %s,", oidcConfig.Issuer)
|
||||||
|
config.IdpManagerConfig.OIDCConfig.Issuer = strings.TrimRight(oidcConfig.Issuer, "/")
|
||||||
|
|
||||||
|
log.Infof("configuring IdpManagerConfig.OIDCConfig.TokenEndpoint with a new value %s,", oidcConfig.TokenEndpoint)
|
||||||
|
config.IdpManagerConfig.OIDCConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
|
||||||
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||||
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
|
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
|
||||||
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||||
@@ -439,7 +447,7 @@ type OIDCConfigResponse struct {
|
|||||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||||
res, err := http.Get(oidcEndpoint)
|
res, err := http.Get(oidcEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
|
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -16,13 +16,14 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
dnsDomain string
|
dnsDomain string
|
||||||
mgmtDataDir string
|
mgmtDataDir string
|
||||||
mgmtConfig string
|
mgmtConfig string
|
||||||
logLevel string
|
logLevel string
|
||||||
logFile string
|
logFile string
|
||||||
disableMetrics bool
|
disableMetrics bool
|
||||||
disableSingleAccMode bool
|
disableSingleAccMode bool
|
||||||
|
idpSignKeyRefreshEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird-mgmt",
|
Use: "netbird-mgmt",
|
||||||
@@ -54,6 +55,7 @@ func init() {
|
|||||||
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||||
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
|
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
|
||||||
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
|
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
|
||||||
|
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
|
||||||
rootCmd.MarkFlagRequired("config") //nolint
|
rootCmd.MarkFlagRequired("config") //nolint
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")
|
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")
|
||||||
|
|||||||
@@ -49,17 +49,16 @@ type AccountManager interface {
|
|||||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||||
GetAccountByUserID(userID string) (*Account, error)
|
|
||||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
MarkPATUsed(tokenID string) error
|
MarkPATUsed(tokenID string) error
|
||||||
IsUserAdmin(userID string) (bool, error)
|
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
AccountExists(accountId string) (*bool, error)
|
AccountExists(accountId string) (*bool, error)
|
||||||
GetPeerByKey(peerKey string) (*Peer, error)
|
GetPeerByKey(peerKey string) (*Peer, error)
|
||||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||||
@@ -70,10 +69,10 @@ type AccountManager interface {
|
|||||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||||
GetPeerNetwork(peerID string) (*Network, error)
|
GetPeerNetwork(peerID string) (*Network, error)
|
||||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||||
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||||
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
@@ -180,6 +179,7 @@ type UserInfo struct {
|
|||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
Status string `json:"-"`
|
Status string `json:"-"`
|
||||||
IsServiceUser bool `json:"is_service_user"`
|
IsServiceUser bool `json:"is_service_user"`
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||||
@@ -903,7 +903,9 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
|||||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||||
users := make(map[string]struct{}, len(account.Users))
|
users := make(map[string]struct{}, len(account.Users))
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
users[user.Id] = struct{}{}
|
if !user.IsServiceUser {
|
||||||
|
users[user.Id] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||||
userData, err := am.lookupCache(users, account.Id)
|
userData, err := am.lookupCache(users, account.Id)
|
||||||
|
|||||||
@@ -91,6 +91,10 @@ const (
|
|||||||
ServiceUserCreated
|
ServiceUserCreated
|
||||||
// ServiceUserDeleted indicates that a user deleted a service user
|
// ServiceUserDeleted indicates that a user deleted a service user
|
||||||
ServiceUserDeleted
|
ServiceUserDeleted
|
||||||
|
// UserBlocked indicates that a user blocked another user
|
||||||
|
UserBlocked
|
||||||
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
|
UserUnblocked
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -184,6 +188,10 @@ const (
|
|||||||
ServiceUserCreatedMessage string = "Service user created"
|
ServiceUserCreatedMessage string = "Service user created"
|
||||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||||
ServiceUserDeletedMessage string = "Service user deleted"
|
ServiceUserDeletedMessage string = "Service user deleted"
|
||||||
|
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
|
||||||
|
UserBlockedMessage string = "User blocked"
|
||||||
|
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||||
|
UserUnblockedMessage string = "User unblocked"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Activity that triggered an Event
|
// Activity that triggered an Event
|
||||||
@@ -282,6 +290,10 @@ func (a Activity) Message() string {
|
|||||||
return ServiceUserCreatedMessage
|
return ServiceUserCreatedMessage
|
||||||
case ServiceUserDeleted:
|
case ServiceUserDeleted:
|
||||||
return ServiceUserDeletedMessage
|
return ServiceUserDeletedMessage
|
||||||
|
case UserBlocked:
|
||||||
|
return UserBlockedMessage
|
||||||
|
case UserUnblocked:
|
||||||
|
return UserUnblockedMessage
|
||||||
default:
|
default:
|
||||||
return "UNKNOWN_ACTIVITY"
|
return "UNKNOWN_ACTIVITY"
|
||||||
}
|
}
|
||||||
@@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
|
|||||||
return "user.join"
|
return "user.join"
|
||||||
case UserInvited:
|
case UserInvited:
|
||||||
return "user.invite"
|
return "user.invite"
|
||||||
|
case UserBlocked:
|
||||||
|
return "user.block"
|
||||||
|
case UserUnblocked:
|
||||||
|
return "user.unblock"
|
||||||
case AccountCreated:
|
case AccountCreated:
|
||||||
return "account.create"
|
return "account.create"
|
||||||
case RuleAdded:
|
case RuleAdded:
|
||||||
|
|||||||
@@ -80,6 +80,8 @@ type HttpServerConfig struct {
|
|||||||
AuthKeysLocation string
|
AuthKeysLocation string
|
||||||
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
|
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
|
||||||
OIDCConfigEndpoint string
|
OIDCConfigEndpoint string
|
||||||
|
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
|
||||||
|
IdpSignKeyRefreshEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
|
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
|
||||||
|
|||||||
@@ -52,7 +52,9 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||||
config.HttpConfig.AuthIssuer,
|
config.HttpConfig.AuthIssuer,
|
||||||
config.GetAuthAudiences(),
|
config.GetAuthAudiences(),
|
||||||
config.HttpConfig.AuthKeysLocation)
|
config.HttpConfig.AuthKeysLocation,
|
||||||
|
config.HttpConfig.IdpSignKeyRefreshEnabled,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,13 +59,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
accountID := vars["id"]
|
accountID := vars["accountId"]
|
||||||
if len(accountID) == 0 {
|
if len(accountID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiAccountsIdJSONBody
|
var req api.PutApiAccountsAccountIdJSONBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET")
|
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET")
|
||||||
router.HandleFunc("/api/accounts/{id}", handler.UpdateAccount).Methods("PUT")
|
router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ info:
|
|||||||
tags:
|
tags:
|
||||||
- name: Users
|
- name: Users
|
||||||
description: Interact with and view information about users.
|
description: Interact with and view information about users.
|
||||||
- name: Tokens
|
|
||||||
description: Interact with and view information about tokens.
|
|
||||||
- name: Peers
|
- name: Peers
|
||||||
description: Interact with and view information about peers.
|
description: Interact with and view information about peers.
|
||||||
- name: Setup Keys
|
- name: Setup Keys
|
||||||
@@ -67,7 +65,7 @@ components:
|
|||||||
status:
|
status:
|
||||||
description: User's status
|
description: User's status
|
||||||
type: string
|
type: string
|
||||||
enum: [ "active","invited","disabled" ]
|
enum: [ "active","invited","blocked" ]
|
||||||
auto_groups:
|
auto_groups:
|
||||||
description: Groups to auto-assign to peers registered by this user
|
description: Groups to auto-assign to peers registered by this user
|
||||||
type: array
|
type: array
|
||||||
@@ -81,6 +79,9 @@ components:
|
|||||||
description: Is true if this user is a service user
|
description: Is true if this user is a service user
|
||||||
type: boolean
|
type: boolean
|
||||||
readOnly: true
|
readOnly: true
|
||||||
|
is_blocked:
|
||||||
|
description: Is true if this user is blocked. Blocked users can't use the system
|
||||||
|
type: boolean
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- email
|
- email
|
||||||
@@ -88,6 +89,7 @@ components:
|
|||||||
- role
|
- role
|
||||||
- auto_groups
|
- auto_groups
|
||||||
- status
|
- status
|
||||||
|
- is_blocked
|
||||||
UserRequest:
|
UserRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -99,21 +101,25 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
is_blocked:
|
||||||
|
description: If set to true then user is blocked and can't use the system
|
||||||
|
type: boolean
|
||||||
required:
|
required:
|
||||||
- role
|
- role
|
||||||
- auto_groups
|
- auto_groups
|
||||||
|
- is_blocked
|
||||||
UserCreateRequest:
|
UserCreateRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
role:
|
|
||||||
description: User's NetBird account role
|
|
||||||
type: string
|
|
||||||
email:
|
email:
|
||||||
description: User's Email to send invite to
|
description: User's Email to send invite to
|
||||||
type: string
|
type: string
|
||||||
name:
|
name:
|
||||||
description: User's full name
|
description: User's full name
|
||||||
type: string
|
type: string
|
||||||
|
role:
|
||||||
|
description: User's NetBird account role
|
||||||
|
type: string
|
||||||
auto_groups:
|
auto_groups:
|
||||||
description: Groups to auto-assign to peers registered by this user
|
description: Groups to auto-assign to peers registered by this user
|
||||||
type: array
|
type: array
|
||||||
@@ -343,6 +349,8 @@ components:
|
|||||||
expires_in:
|
expires_in:
|
||||||
description: Expiration in days
|
description: Expiration in days
|
||||||
type: integer
|
type: integer
|
||||||
|
minimum: 1
|
||||||
|
maximum: 365
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- expires_in
|
- expires_in
|
||||||
@@ -374,33 +382,6 @@ components:
|
|||||||
$ref: '#/components/schemas/PeerMinimum'
|
$ref: '#/components/schemas/PeerMinimum'
|
||||||
required:
|
required:
|
||||||
- peers
|
- peers
|
||||||
PatchMinimum:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
op:
|
|
||||||
description: Patch operation type
|
|
||||||
type: string
|
|
||||||
enum: [ "replace","add","remove" ]
|
|
||||||
value:
|
|
||||||
description: Values to be applied
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- op
|
|
||||||
- value
|
|
||||||
GroupPatchOperation:
|
|
||||||
allOf:
|
|
||||||
- $ref: '#/components/schemas/PatchMinimum'
|
|
||||||
- type: object
|
|
||||||
properties:
|
|
||||||
path:
|
|
||||||
description: Group field to update in form /<field>
|
|
||||||
type: string
|
|
||||||
enum: [ "name","peers" ]
|
|
||||||
required:
|
|
||||||
- path
|
|
||||||
|
|
||||||
RuleMinimum:
|
RuleMinimum:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -446,17 +427,6 @@ components:
|
|||||||
required:
|
required:
|
||||||
- sources
|
- sources
|
||||||
- destinations
|
- destinations
|
||||||
RulePatchOperation:
|
|
||||||
allOf:
|
|
||||||
- $ref: '#/components/schemas/PatchMinimum'
|
|
||||||
- type: object
|
|
||||||
properties:
|
|
||||||
path:
|
|
||||||
description: Rule field to update in form /<field>
|
|
||||||
type: string
|
|
||||||
enum: [ "name","description","disabled","flow","sources","destinations" ]
|
|
||||||
required:
|
|
||||||
- path
|
|
||||||
PolicyRule:
|
PolicyRule:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -585,17 +555,6 @@ components:
|
|||||||
- id
|
- id
|
||||||
- network_type
|
- network_type
|
||||||
- $ref: '#/components/schemas/RouteRequest'
|
- $ref: '#/components/schemas/RouteRequest'
|
||||||
RoutePatchOperation:
|
|
||||||
allOf:
|
|
||||||
- $ref: '#/components/schemas/PatchMinimum'
|
|
||||||
- type: object
|
|
||||||
properties:
|
|
||||||
path:
|
|
||||||
description: Route field to update in form /<field>
|
|
||||||
type: string
|
|
||||||
enum: [ "network","network_id","description","enabled","peer","metric","masquerade", "groups" ]
|
|
||||||
required:
|
|
||||||
- path
|
|
||||||
Nameserver:
|
Nameserver:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -667,17 +626,6 @@ components:
|
|||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- $ref: '#/components/schemas/NameserverGroupRequest'
|
- $ref: '#/components/schemas/NameserverGroupRequest'
|
||||||
NameserverGroupPatchOperation:
|
|
||||||
allOf:
|
|
||||||
- $ref: '#/components/schemas/PatchMinimum'
|
|
||||||
- type: object
|
|
||||||
properties:
|
|
||||||
path:
|
|
||||||
description: Nameserver group field to update in form /<field>
|
|
||||||
type: string
|
|
||||||
enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ]
|
|
||||||
required:
|
|
||||||
- path
|
|
||||||
DNSSettings:
|
DNSSettings:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -705,7 +653,7 @@ components:
|
|||||||
description: The string code of the activity that occurred during the event
|
description: The string code of the activity that occurred during the event
|
||||||
type: string
|
type: string
|
||||||
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
||||||
"user.role.update",
|
"user.role.update", "user.block", "user.unblock",
|
||||||
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
||||||
"setupkey.group.delete", "setupkey.group.add",
|
"setupkey.group.delete", "setupkey.group.add",
|
||||||
"rule.add", "rule.delete", "rule.update",
|
"rule.add", "rule.delete", "rule.update",
|
||||||
@@ -761,15 +709,23 @@ components:
|
|||||||
type: http
|
type: http
|
||||||
scheme: bearer
|
scheme: bearer
|
||||||
bearerFormat: JWT
|
bearerFormat: JWT
|
||||||
|
TokenAuth:
|
||||||
|
type: apiKey
|
||||||
|
in: header
|
||||||
|
name: Authorization
|
||||||
|
description: >-
|
||||||
|
Enter the token with the `Token` prefix, e.g. "Token nbp_F3f0d.....".
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
paths:
|
paths:
|
||||||
/api/accounts:
|
/api/accounts:
|
||||||
get:
|
get:
|
||||||
summary: Returns a list of accounts of a user. Always returns a list of one account. Only available for admin users.
|
summary: Returns a list of accounts of a user. Always returns a list of one account.
|
||||||
tags: [ Accounts ]
|
tags: [ Accounts ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON array of accounts
|
description: A JSON array of accounts
|
||||||
@@ -787,19 +743,20 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/accounts/{id}:
|
/api/accounts/{accountId}:
|
||||||
put:
|
put:
|
||||||
summary: Update information about an account
|
summary: Update information about an account
|
||||||
tags: [ Accounts ]
|
tags: [ Accounts ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: accountId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Account ID
|
description: The unique identifier of an account
|
||||||
requestBody:
|
requestBody:
|
||||||
description: update an account
|
description: update an account
|
||||||
content:
|
content:
|
||||||
@@ -832,12 +789,13 @@ paths:
|
|||||||
tags: [ Users ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: query
|
- in: query
|
||||||
name: service_user
|
name: service_user
|
||||||
schema:
|
schema:
|
||||||
type: boolean
|
type: boolean
|
||||||
description: Filters users and returns either normal users or service users
|
description: Filters users and returns either regular users or service users
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON array of Users
|
description: A JSON array of Users
|
||||||
@@ -855,12 +813,12 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/users/:
|
|
||||||
post:
|
post:
|
||||||
summary: Create a User (invite)
|
summary: Create a User (or invite)
|
||||||
tags: [ Users ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: User invite information
|
description: User invite information
|
||||||
content:
|
content:
|
||||||
@@ -882,19 +840,20 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/users/{id}:
|
/api/users/{userId}:
|
||||||
put:
|
put:
|
||||||
summary: Update information about a User
|
summary: Update information about a User
|
||||||
tags: [ Users ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
requestBody:
|
requestBody:
|
||||||
description: User update
|
description: User update
|
||||||
content:
|
content:
|
||||||
@@ -923,11 +882,11 @@ paths:
|
|||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -943,16 +902,17 @@ paths:
|
|||||||
/api/users/{userId}/tokens:
|
/api/users/{userId}/tokens:
|
||||||
get:
|
get:
|
||||||
summary: Returns a list of all tokens for a user
|
summary: Returns a list of all tokens for a user
|
||||||
tags: [ Tokens ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: []
|
- BearerAuth: []
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: userId
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of PersonalAccessTokens
|
description: A JSON Array of PersonalAccessTokens
|
||||||
@@ -971,17 +931,18 @@ paths:
|
|||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
post:
|
post:
|
||||||
summary: Create a new token
|
summary: Create a new token for a user
|
||||||
tags: [ Tokens ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: userId
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
requestBody:
|
requestBody:
|
||||||
description: PersonalAccessToken create parameters
|
description: PersonalAccessToken create parameters
|
||||||
content:
|
content:
|
||||||
@@ -1005,23 +966,24 @@ paths:
|
|||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/users/{userId}/tokens/{tokenId}:
|
/api/users/{userId}/tokens/{tokenId}:
|
||||||
get:
|
get:
|
||||||
summary: Returns a specific token
|
summary: Returns a specific token for a user
|
||||||
tags: [ Tokens ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: userId
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
- in: path
|
- in: path
|
||||||
name: tokenId
|
name: tokenId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Token ID
|
description: The unique identifier of a token
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A PersonalAccessTokens Object
|
description: A PersonalAccessTokens Object
|
||||||
@@ -1038,23 +1000,24 @@ paths:
|
|||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
delete:
|
delete:
|
||||||
summary: Delete a token
|
summary: Delete a token for a user
|
||||||
tags: [ Tokens ]
|
tags: [ Users ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: userId
|
name: userId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The User ID
|
description: The unique identifier of a user
|
||||||
- in: path
|
- in: path
|
||||||
name: tokenId
|
name: tokenId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Token ID
|
description: The unique identifier of a token
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1073,6 +1036,7 @@ paths:
|
|||||||
tags: [ Peers ]
|
tags: [ Peers ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Peers
|
description: A JSON Array of Peers
|
||||||
@@ -1090,19 +1054,20 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/peers/{id}:
|
/api/peers/{peerId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a peer
|
summary: Get information about a peer
|
||||||
tags: [ Peers ]
|
tags: [ Peers ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: peerId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Peer ID
|
description: The unique identifier of a peer
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Peer object
|
description: A Peer object
|
||||||
@@ -1123,13 +1088,14 @@ paths:
|
|||||||
tags: [ Peers ]
|
tags: [ Peers ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: peerId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Peer ID
|
description: The unique identifier of a peer
|
||||||
requestBody:
|
requestBody:
|
||||||
description: update a peer
|
description: update a peer
|
||||||
content:
|
content:
|
||||||
@@ -1167,13 +1133,14 @@ paths:
|
|||||||
tags: [ Peers ]
|
tags: [ Peers ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: peerId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Peer ID
|
description: The unique identifier of a peer
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1192,6 +1159,7 @@ paths:
|
|||||||
tags: [ Setup Keys ]
|
tags: [ Setup Keys ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Setup keys
|
description: A JSON Array of Setup keys
|
||||||
@@ -1214,6 +1182,7 @@ paths:
|
|||||||
tags: [ Setup Keys ]
|
tags: [ Setup Keys ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Setup Key request
|
description: New Setup Key request
|
||||||
content:
|
content:
|
||||||
@@ -1235,19 +1204,20 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/setup-keys/{id}:
|
/api/setup-keys/{keyId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Setup Key
|
summary: Get information about a Setup Key
|
||||||
tags: [ Setup Keys ]
|
tags: [ Setup Keys ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: keyId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Setup Key ID
|
description: The unique identifier of a setup key
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Setup Key object
|
description: A Setup Key object
|
||||||
@@ -1268,13 +1238,14 @@ paths:
|
|||||||
tags: [ Setup Keys ]
|
tags: [ Setup Keys ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: keyId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Setup Key ID
|
description: The unique identifier of a setup key
|
||||||
requestBody:
|
requestBody:
|
||||||
description: update to Setup Key
|
description: update to Setup Key
|
||||||
content:
|
content:
|
||||||
@@ -1296,36 +1267,13 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
delete:
|
|
||||||
summary: Delete a Setup Key
|
|
||||||
tags: [ Setup Keys ]
|
|
||||||
security:
|
|
||||||
- BearerAuth: [ ]
|
|
||||||
parameters:
|
|
||||||
- in: path
|
|
||||||
name: id
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: The Setup Key ID
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: Delete status code
|
|
||||||
content: { }
|
|
||||||
'400':
|
|
||||||
"$ref": "#/components/responses/bad_request"
|
|
||||||
'401':
|
|
||||||
"$ref": "#/components/responses/requires_authentication"
|
|
||||||
'403':
|
|
||||||
"$ref": "#/components/responses/forbidden"
|
|
||||||
'500':
|
|
||||||
"$ref": "#/components/responses/internal_error"
|
|
||||||
/api/groups:
|
/api/groups:
|
||||||
get:
|
get:
|
||||||
summary: Returns a list of all Groups
|
summary: Returns a list of all Groups
|
||||||
tags: [ Groups ]
|
tags: [ Groups ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Groups
|
description: A JSON Array of Groups
|
||||||
@@ -1348,6 +1296,7 @@ paths:
|
|||||||
tags: [ Groups ]
|
tags: [ Groups ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Group request
|
description: New Group request
|
||||||
content:
|
content:
|
||||||
@@ -1378,19 +1327,20 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/groups/{id}:
|
/api/groups/{groupId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Group
|
summary: Get information about a Group
|
||||||
tags: [ Groups ]
|
tags: [ Groups ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: groupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Group ID
|
description: The unique identifier of a group
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Group object
|
description: A Group object
|
||||||
@@ -1411,13 +1361,14 @@ paths:
|
|||||||
tags: [ Groups ]
|
tags: [ Groups ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: groupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Group ID
|
description: The unique identifier of a group
|
||||||
requestBody:
|
requestBody:
|
||||||
description: Update Group request
|
description: Update Group request
|
||||||
content:
|
content:
|
||||||
@@ -1446,53 +1397,19 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
patch:
|
|
||||||
summary: Update information about a Group
|
|
||||||
tags: [ Groups ]
|
|
||||||
security:
|
|
||||||
- BearerAuth: [ ]
|
|
||||||
parameters:
|
|
||||||
- in: path
|
|
||||||
name: id
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: The Group ID
|
|
||||||
requestBody:
|
|
||||||
description: Update Group request using a list of json patch objects
|
|
||||||
content:
|
|
||||||
'application/json':
|
|
||||||
schema:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/GroupPatchOperation'
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: A Group object
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Group'
|
|
||||||
'400':
|
|
||||||
"$ref": "#/components/responses/bad_request"
|
|
||||||
'401':
|
|
||||||
"$ref": "#/components/responses/requires_authentication"
|
|
||||||
'403':
|
|
||||||
"$ref": "#/components/responses/forbidden"
|
|
||||||
'500':
|
|
||||||
"$ref": "#/components/responses/internal_error"
|
|
||||||
delete:
|
delete:
|
||||||
summary: Delete a Group
|
summary: Delete a Group
|
||||||
tags: [ Groups ]
|
tags: [ Groups ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: groupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Group ID
|
description: The unique identifier of a group
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1511,6 +1428,7 @@ paths:
|
|||||||
tags: [ Rules ]
|
tags: [ Rules ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Rules
|
description: A JSON Array of Rules
|
||||||
@@ -1533,6 +1451,7 @@ paths:
|
|||||||
tags: [ Rules ]
|
tags: [ Rules ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Rule request
|
description: New Rule request
|
||||||
content:
|
content:
|
||||||
@@ -1557,19 +1476,20 @@ paths:
|
|||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/Rule'
|
$ref: '#/components/schemas/Rule'
|
||||||
/api/rules/{id}:
|
/api/rules/{ruleId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Rules
|
summary: Get information about a Rules
|
||||||
tags: [ Rules ]
|
tags: [ Rules ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: ruleId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Rule ID
|
description: The unique identifier of a rule
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Rule object
|
description: A Rule object
|
||||||
@@ -1590,13 +1510,14 @@ paths:
|
|||||||
tags: [ Rules ]
|
tags: [ Rules ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: ruleId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Rule ID
|
description: The unique identifier of a rule
|
||||||
requestBody:
|
requestBody:
|
||||||
description: Update Rule request
|
description: Update Rule request
|
||||||
content:
|
content:
|
||||||
@@ -1634,13 +1555,14 @@ paths:
|
|||||||
tags: [ Rules ]
|
tags: [ Rules ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: ruleId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Rule ID
|
description: The unique identifier of a rule
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1659,6 +1581,7 @@ paths:
|
|||||||
tags: [ Policies ]
|
tags: [ Policies ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Policies
|
description: A JSON Array of Policies
|
||||||
@@ -1681,6 +1604,7 @@ paths:
|
|||||||
tags: [ Policies ]
|
tags: [ Policies ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Policy request
|
description: New Policy request
|
||||||
content:
|
content:
|
||||||
@@ -1695,19 +1619,20 @@ paths:
|
|||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/Policy'
|
$ref: '#/components/schemas/Policy'
|
||||||
/api/policies/{id}:
|
/api/policies/{policyId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Policies
|
summary: Get information about a Policies
|
||||||
tags: [ Policies ]
|
tags: [ Policies ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: policyId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Policy ID
|
description: The unique identifier of a policy
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Policy object
|
description: A Policy object
|
||||||
@@ -1728,13 +1653,14 @@ paths:
|
|||||||
tags: [ Policies ]
|
tags: [ Policies ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: policyId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Policy ID
|
description: The unique identifier of a policy
|
||||||
requestBody:
|
requestBody:
|
||||||
description: Update Policy request
|
description: Update Policy request
|
||||||
content:
|
content:
|
||||||
@@ -1762,13 +1688,14 @@ paths:
|
|||||||
tags: [ Policies ]
|
tags: [ Policies ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: policyId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Policy ID
|
description: The unique identifier of a policy
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1787,6 +1714,7 @@ paths:
|
|||||||
tags: [ Routes ]
|
tags: [ Routes ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Routes
|
description: A JSON Array of Routes
|
||||||
@@ -1809,6 +1737,7 @@ paths:
|
|||||||
tags: [ Routes ]
|
tags: [ Routes ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Routes request
|
description: New Routes request
|
||||||
content:
|
content:
|
||||||
@@ -1831,19 +1760,20 @@ paths:
|
|||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
|
||||||
/api/routes/{id}:
|
/api/routes/{routeId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Routes
|
summary: Get information about a Routes
|
||||||
tags: [ Routes ]
|
tags: [ Routes ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: routeId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Route ID
|
description: The unique identifier of a route
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Route object
|
description: A Route object
|
||||||
@@ -1864,13 +1794,14 @@ paths:
|
|||||||
tags: [ Routes ]
|
tags: [ Routes ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: routeId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Route ID
|
description: The unique identifier of a route
|
||||||
requestBody:
|
requestBody:
|
||||||
description: Update Route request
|
description: Update Route request
|
||||||
content:
|
content:
|
||||||
@@ -1892,53 +1823,19 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
patch:
|
|
||||||
summary: Update information about a Route
|
|
||||||
tags: [ Routes ]
|
|
||||||
security:
|
|
||||||
- BearerAuth: [ ]
|
|
||||||
parameters:
|
|
||||||
- in: path
|
|
||||||
name: id
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: The Route ID
|
|
||||||
requestBody:
|
|
||||||
description: Update Route request using a list of json patch objects
|
|
||||||
content:
|
|
||||||
'application/json':
|
|
||||||
schema:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/RoutePatchOperation'
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: A Route object
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Route'
|
|
||||||
'400':
|
|
||||||
"$ref": "#/components/responses/bad_request"
|
|
||||||
'401':
|
|
||||||
"$ref": "#/components/responses/requires_authentication"
|
|
||||||
'403':
|
|
||||||
"$ref": "#/components/responses/forbidden"
|
|
||||||
'500':
|
|
||||||
"$ref": "#/components/responses/internal_error"
|
|
||||||
delete:
|
delete:
|
||||||
summary: Delete a Route
|
summary: Delete a Route
|
||||||
tags: [ Routes ]
|
tags: [ Routes ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: routeId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Route ID
|
description: The unique identifier of a route
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -1957,6 +1854,7 @@ paths:
|
|||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Nameserver Groups
|
description: A JSON Array of Nameserver Groups
|
||||||
@@ -1979,6 +1877,7 @@ paths:
|
|||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: New Nameserver Groups request
|
description: New Nameserver Groups request
|
||||||
content:
|
content:
|
||||||
@@ -2001,19 +1900,20 @@ paths:
|
|||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
|
||||||
/api/dns/nameservers/{id}:
|
/api/dns/nameservers/{nsgroupId}:
|
||||||
get:
|
get:
|
||||||
summary: Get information about a Nameserver Groups
|
summary: Get information about a Nameserver Groups
|
||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: nsgroupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Nameserver Group ID
|
description: The unique identifier of a Nameserver Group
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Nameserver Group object
|
description: A Nameserver Group object
|
||||||
@@ -2034,13 +1934,14 @@ paths:
|
|||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: nsgroupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Nameserver Group ID
|
description: The unique identifier of a Nameserver Group
|
||||||
requestBody:
|
requestBody:
|
||||||
description: Update Nameserver Group request
|
description: Update Nameserver Group request
|
||||||
content:
|
content:
|
||||||
@@ -2062,53 +1963,19 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
patch:
|
|
||||||
summary: Update information about a Nameserver Group
|
|
||||||
tags: [ DNS ]
|
|
||||||
security:
|
|
||||||
- BearerAuth: [ ]
|
|
||||||
parameters:
|
|
||||||
- in: path
|
|
||||||
name: id
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: The Nameserver Group ID
|
|
||||||
requestBody:
|
|
||||||
description: Update Nameserver Group request using a list of json patch objects
|
|
||||||
content:
|
|
||||||
'application/json':
|
|
||||||
schema:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/NameserverGroupPatchOperation'
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: A Nameserver Group object
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/NameserverGroup'
|
|
||||||
'400':
|
|
||||||
"$ref": "#/components/responses/bad_request"
|
|
||||||
'401':
|
|
||||||
"$ref": "#/components/responses/requires_authentication"
|
|
||||||
'403':
|
|
||||||
"$ref": "#/components/responses/forbidden"
|
|
||||||
'500':
|
|
||||||
"$ref": "#/components/responses/internal_error"
|
|
||||||
delete:
|
delete:
|
||||||
summary: Delete a Nameserver Group
|
summary: Delete a Nameserver Group
|
||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
parameters:
|
parameters:
|
||||||
- in: path
|
- in: path
|
||||||
name: id
|
name: nsgroupId
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
description: The Nameserver Group ID
|
description: The unique identifier of a Nameserver Group
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: Delete status code
|
description: Delete status code
|
||||||
@@ -2128,6 +1995,7 @@ paths:
|
|||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Object of DNS Setting
|
description: A JSON Object of DNS Setting
|
||||||
@@ -2149,6 +2017,7 @@ paths:
|
|||||||
tags: [ DNS ]
|
tags: [ DNS ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
requestBody:
|
requestBody:
|
||||||
description: A DNS settings object
|
description: A DNS settings object
|
||||||
content:
|
content:
|
||||||
@@ -2176,6 +2045,7 @@ paths:
|
|||||||
tags: [ Events ]
|
tags: [ Events ]
|
||||||
security:
|
security:
|
||||||
- BearerAuth: [ ]
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A JSON Array of Events
|
description: A JSON Array of Events
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
BearerAuthScopes = "BearerAuth.Scopes"
|
BearerAuthScopes = "BearerAuth.Scopes"
|
||||||
|
TokenAuthScopes = "TokenAuth.Scopes"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for EventActivityCode.
|
// Defines values for EventActivityCode.
|
||||||
@@ -45,6 +46,7 @@ const (
|
|||||||
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
||||||
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
||||||
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
||||||
|
EventActivityCodeUserBlock EventActivityCode = "user.block"
|
||||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||||
@@ -52,19 +54,7 @@ const (
|
|||||||
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
||||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||||
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
||||||
)
|
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
|
||||||
|
|
||||||
// Defines values for GroupPatchOperationOp.
|
|
||||||
const (
|
|
||||||
GroupPatchOperationOpAdd GroupPatchOperationOp = "add"
|
|
||||||
GroupPatchOperationOpRemove GroupPatchOperationOp = "remove"
|
|
||||||
GroupPatchOperationOpReplace GroupPatchOperationOp = "replace"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for GroupPatchOperationPath.
|
|
||||||
const (
|
|
||||||
GroupPatchOperationPathName GroupPatchOperationPath = "name"
|
|
||||||
GroupPatchOperationPathPeers GroupPatchOperationPath = "peers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for NameserverNsType.
|
// Defines values for NameserverNsType.
|
||||||
@@ -72,61 +62,17 @@ const (
|
|||||||
NameserverNsTypeUdp NameserverNsType = "udp"
|
NameserverNsTypeUdp NameserverNsType = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for NameserverGroupPatchOperationOp.
|
|
||||||
const (
|
|
||||||
NameserverGroupPatchOperationOpAdd NameserverGroupPatchOperationOp = "add"
|
|
||||||
NameserverGroupPatchOperationOpRemove NameserverGroupPatchOperationOp = "remove"
|
|
||||||
NameserverGroupPatchOperationOpReplace NameserverGroupPatchOperationOp = "replace"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for NameserverGroupPatchOperationPath.
|
|
||||||
const (
|
|
||||||
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
|
|
||||||
NameserverGroupPatchOperationPathDomains NameserverGroupPatchOperationPath = "domains"
|
|
||||||
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
|
|
||||||
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
|
|
||||||
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
|
|
||||||
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
|
|
||||||
NameserverGroupPatchOperationPathPrimary NameserverGroupPatchOperationPath = "primary"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for PatchMinimumOp.
|
|
||||||
const (
|
|
||||||
PatchMinimumOpAdd PatchMinimumOp = "add"
|
|
||||||
PatchMinimumOpRemove PatchMinimumOp = "remove"
|
|
||||||
PatchMinimumOpReplace PatchMinimumOp = "replace"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for PolicyRuleAction.
|
// Defines values for PolicyRuleAction.
|
||||||
const (
|
const (
|
||||||
PolicyRuleActionAccept PolicyRuleAction = "accept"
|
PolicyRuleActionAccept PolicyRuleAction = "accept"
|
||||||
PolicyRuleActionDrop PolicyRuleAction = "drop"
|
PolicyRuleActionDrop PolicyRuleAction = "drop"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for RoutePatchOperationOp.
|
|
||||||
const (
|
|
||||||
RoutePatchOperationOpAdd RoutePatchOperationOp = "add"
|
|
||||||
RoutePatchOperationOpRemove RoutePatchOperationOp = "remove"
|
|
||||||
RoutePatchOperationOpReplace RoutePatchOperationOp = "replace"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for RoutePatchOperationPath.
|
|
||||||
const (
|
|
||||||
RoutePatchOperationPathDescription RoutePatchOperationPath = "description"
|
|
||||||
RoutePatchOperationPathEnabled RoutePatchOperationPath = "enabled"
|
|
||||||
RoutePatchOperationPathGroups RoutePatchOperationPath = "groups"
|
|
||||||
RoutePatchOperationPathMasquerade RoutePatchOperationPath = "masquerade"
|
|
||||||
RoutePatchOperationPathMetric RoutePatchOperationPath = "metric"
|
|
||||||
RoutePatchOperationPathNetwork RoutePatchOperationPath = "network"
|
|
||||||
RoutePatchOperationPathNetworkId RoutePatchOperationPath = "network_id"
|
|
||||||
RoutePatchOperationPathPeer RoutePatchOperationPath = "peer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defines values for UserStatus.
|
// Defines values for UserStatus.
|
||||||
const (
|
const (
|
||||||
UserStatusActive UserStatus = "active"
|
UserStatusActive UserStatus = "active"
|
||||||
UserStatusDisabled UserStatus = "disabled"
|
UserStatusBlocked UserStatus = "blocked"
|
||||||
UserStatusInvited UserStatus = "invited"
|
UserStatusInvited UserStatus = "invited"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account defines model for Account.
|
// Account defines model for Account.
|
||||||
@@ -205,24 +151,6 @@ type GroupMinimum struct {
|
|||||||
PeersCount int `json:"peers_count"`
|
PeersCount int `json:"peers_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupPatchOperation defines model for GroupPatchOperation.
|
|
||||||
type GroupPatchOperation struct {
|
|
||||||
// Op Patch operation type
|
|
||||||
Op GroupPatchOperationOp `json:"op"`
|
|
||||||
|
|
||||||
// Path Group field to update in form /<field>
|
|
||||||
Path GroupPatchOperationPath `json:"path"`
|
|
||||||
|
|
||||||
// Value Values to be applied
|
|
||||||
Value []string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupPatchOperationOp Patch operation type
|
|
||||||
type GroupPatchOperationOp string
|
|
||||||
|
|
||||||
// GroupPatchOperationPath Group field to update in form /<field>
|
|
||||||
type GroupPatchOperationPath string
|
|
||||||
|
|
||||||
// Nameserver defines model for Nameserver.
|
// Nameserver defines model for Nameserver.
|
||||||
type Nameserver struct {
|
type Nameserver struct {
|
||||||
// Ip Nameserver IP
|
// Ip Nameserver IP
|
||||||
@@ -265,24 +193,6 @@ type NameserverGroup struct {
|
|||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
|
|
||||||
type NameserverGroupPatchOperation struct {
|
|
||||||
// Op Patch operation type
|
|
||||||
Op NameserverGroupPatchOperationOp `json:"op"`
|
|
||||||
|
|
||||||
// Path Nameserver group field to update in form /<field>
|
|
||||||
Path NameserverGroupPatchOperationPath `json:"path"`
|
|
||||||
|
|
||||||
// Value Values to be applied
|
|
||||||
Value []string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NameserverGroupPatchOperationOp Patch operation type
|
|
||||||
type NameserverGroupPatchOperationOp string
|
|
||||||
|
|
||||||
// NameserverGroupPatchOperationPath Nameserver group field to update in form /<field>
|
|
||||||
type NameserverGroupPatchOperationPath string
|
|
||||||
|
|
||||||
// NameserverGroupRequest defines model for NameserverGroupRequest.
|
// NameserverGroupRequest defines model for NameserverGroupRequest.
|
||||||
type NameserverGroupRequest struct {
|
type NameserverGroupRequest struct {
|
||||||
// Description Nameserver group description
|
// Description Nameserver group description
|
||||||
@@ -307,18 +217,6 @@ type NameserverGroupRequest struct {
|
|||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchMinimum defines model for PatchMinimum.
|
|
||||||
type PatchMinimum struct {
|
|
||||||
// Op Patch operation type
|
|
||||||
Op PatchMinimumOp `json:"op"`
|
|
||||||
|
|
||||||
// Value Values to be applied
|
|
||||||
Value []string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PatchMinimumOp Patch operation type
|
|
||||||
type PatchMinimumOp string
|
|
||||||
|
|
||||||
// Peer defines model for Peer.
|
// Peer defines model for Peer.
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// Connected Peer to Management connection status
|
// Connected Peer to Management connection status
|
||||||
@@ -516,24 +414,6 @@ type Route struct {
|
|||||||
Peer string `json:"peer"`
|
Peer string `json:"peer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutePatchOperation defines model for RoutePatchOperation.
|
|
||||||
type RoutePatchOperation struct {
|
|
||||||
// Op Patch operation type
|
|
||||||
Op RoutePatchOperationOp `json:"op"`
|
|
||||||
|
|
||||||
// Path Route field to update in form /<field>
|
|
||||||
Path RoutePatchOperationPath `json:"path"`
|
|
||||||
|
|
||||||
// Value Values to be applied
|
|
||||||
Value []string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoutePatchOperationOp Patch operation type
|
|
||||||
type RoutePatchOperationOp string
|
|
||||||
|
|
||||||
// RoutePatchOperationPath Route field to update in form /<field>
|
|
||||||
type RoutePatchOperationPath string
|
|
||||||
|
|
||||||
// RouteRequest defines model for RouteRequest.
|
// RouteRequest defines model for RouteRequest.
|
||||||
type RouteRequest struct {
|
type RouteRequest struct {
|
||||||
// Description Route description
|
// Description Route description
|
||||||
@@ -674,6 +554,9 @@ type User struct {
|
|||||||
// Id User ID
|
// Id User ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
|
|
||||||
// IsCurrent Is true if authenticated user is the same as this user
|
// IsCurrent Is true if authenticated user is the same as this user
|
||||||
IsCurrent *bool `json:"is_current,omitempty"`
|
IsCurrent *bool `json:"is_current,omitempty"`
|
||||||
|
|
||||||
@@ -716,35 +599,32 @@ type UserRequest struct {
|
|||||||
// AutoGroups Groups to auto-assign to peers registered by this user
|
// AutoGroups Groups to auto-assign to peers registered by this user
|
||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
|
// IsBlocked If set to true then user is blocked and can't use the system
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
|
|
||||||
// Role User's NetBird account role
|
// Role User's NetBird account role
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PutApiAccountsIdJSONBody defines parameters for PutApiAccountsId.
|
// PutApiAccountsAccountIdJSONBody defines parameters for PutApiAccountsAccountId.
|
||||||
type PutApiAccountsIdJSONBody struct {
|
type PutApiAccountsAccountIdJSONBody struct {
|
||||||
Settings AccountSettings `json:"settings"`
|
Settings AccountSettings `json:"settings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchApiDnsNameserversIdJSONBody defines parameters for PatchApiDnsNameserversId.
|
|
||||||
type PatchApiDnsNameserversIdJSONBody = []NameserverGroupPatchOperation
|
|
||||||
|
|
||||||
// PostApiGroupsJSONBody defines parameters for PostApiGroups.
|
// PostApiGroupsJSONBody defines parameters for PostApiGroups.
|
||||||
type PostApiGroupsJSONBody struct {
|
type PostApiGroupsJSONBody struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Peers *[]string `json:"peers,omitempty"`
|
Peers *[]string `json:"peers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchApiGroupsIdJSONBody defines parameters for PatchApiGroupsId.
|
// PutApiGroupsGroupIdJSONBody defines parameters for PutApiGroupsGroupId.
|
||||||
type PatchApiGroupsIdJSONBody = []GroupPatchOperation
|
type PutApiGroupsGroupIdJSONBody struct {
|
||||||
|
|
||||||
// PutApiGroupsIdJSONBody defines parameters for PutApiGroupsId.
|
|
||||||
type PutApiGroupsIdJSONBody struct {
|
|
||||||
Name *string `json:"Name,omitempty"`
|
Name *string `json:"Name,omitempty"`
|
||||||
Peers *[]string `json:"Peers,omitempty"`
|
Peers *[]string `json:"Peers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PutApiPeersIdJSONBody defines parameters for PutApiPeersId.
|
// PutApiPeersPeerIdJSONBody defines parameters for PutApiPeersPeerId.
|
||||||
type PutApiPeersIdJSONBody struct {
|
type PutApiPeersPeerIdJSONBody struct {
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
@@ -753,11 +633,8 @@ type PutApiPeersIdJSONBody struct {
|
|||||||
// PostApiPoliciesJSONBody defines parameters for PostApiPolicies.
|
// PostApiPoliciesJSONBody defines parameters for PostApiPolicies.
|
||||||
type PostApiPoliciesJSONBody = PolicyMinimum
|
type PostApiPoliciesJSONBody = PolicyMinimum
|
||||||
|
|
||||||
// PutApiPoliciesIdJSONBody defines parameters for PutApiPoliciesId.
|
// PutApiPoliciesPolicyIdJSONBody defines parameters for PutApiPoliciesPolicyId.
|
||||||
type PutApiPoliciesIdJSONBody = PolicyMinimum
|
type PutApiPoliciesPolicyIdJSONBody = PolicyMinimum
|
||||||
|
|
||||||
// PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId.
|
|
||||||
type PatchApiRoutesIdJSONBody = []RoutePatchOperation
|
|
||||||
|
|
||||||
// PostApiRulesJSONBody defines parameters for PostApiRules.
|
// PostApiRulesJSONBody defines parameters for PostApiRules.
|
||||||
type PostApiRulesJSONBody struct {
|
type PostApiRulesJSONBody struct {
|
||||||
@@ -776,8 +653,8 @@ type PostApiRulesJSONBody struct {
|
|||||||
Sources *[]string `json:"sources,omitempty"`
|
Sources *[]string `json:"sources,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PutApiRulesIdJSONBody defines parameters for PutApiRulesId.
|
// PutApiRulesRuleIdJSONBody defines parameters for PutApiRulesRuleId.
|
||||||
type PutApiRulesIdJSONBody struct {
|
type PutApiRulesRuleIdJSONBody struct {
|
||||||
// Description Rule friendly description
|
// Description Rule friendly description
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Destinations *[]string `json:"destinations,omitempty"`
|
Destinations *[]string `json:"destinations,omitempty"`
|
||||||
@@ -795,21 +672,18 @@ type PutApiRulesIdJSONBody struct {
|
|||||||
|
|
||||||
// GetApiUsersParams defines parameters for GetApiUsers.
|
// GetApiUsersParams defines parameters for GetApiUsers.
|
||||||
type GetApiUsersParams struct {
|
type GetApiUsersParams struct {
|
||||||
// ServiceUser Filters users and returns either normal users or service users
|
// ServiceUser Filters users and returns either regular users or service users
|
||||||
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
|
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
|
// PutApiAccountsAccountIdJSONRequestBody defines body for PutApiAccountsAccountId for application/json ContentType.
|
||||||
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody
|
type PutApiAccountsAccountIdJSONRequestBody PutApiAccountsAccountIdJSONBody
|
||||||
|
|
||||||
// PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType.
|
// PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType.
|
||||||
type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest
|
type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest
|
||||||
|
|
||||||
// PatchApiDnsNameserversIdJSONRequestBody defines body for PatchApiDnsNameserversId for application/json ContentType.
|
// PutApiDnsNameserversNsgroupIdJSONRequestBody defines body for PutApiDnsNameserversNsgroupId for application/json ContentType.
|
||||||
type PatchApiDnsNameserversIdJSONRequestBody = PatchApiDnsNameserversIdJSONBody
|
type PutApiDnsNameserversNsgroupIdJSONRequestBody = NameserverGroupRequest
|
||||||
|
|
||||||
// PutApiDnsNameserversIdJSONRequestBody defines body for PutApiDnsNameserversId for application/json ContentType.
|
|
||||||
type PutApiDnsNameserversIdJSONRequestBody = NameserverGroupRequest
|
|
||||||
|
|
||||||
// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType.
|
// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType.
|
||||||
type PutApiDnsSettingsJSONRequestBody = DNSSettings
|
type PutApiDnsSettingsJSONRequestBody = DNSSettings
|
||||||
@@ -817,47 +691,41 @@ type PutApiDnsSettingsJSONRequestBody = DNSSettings
|
|||||||
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
||||||
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
|
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
|
||||||
|
|
||||||
// PatchApiGroupsIdJSONRequestBody defines body for PatchApiGroupsId for application/json ContentType.
|
// PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType.
|
||||||
type PatchApiGroupsIdJSONRequestBody = PatchApiGroupsIdJSONBody
|
type PutApiGroupsGroupIdJSONRequestBody PutApiGroupsGroupIdJSONBody
|
||||||
|
|
||||||
// PutApiGroupsIdJSONRequestBody defines body for PutApiGroupsId for application/json ContentType.
|
// PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType.
|
||||||
type PutApiGroupsIdJSONRequestBody PutApiGroupsIdJSONBody
|
type PutApiPeersPeerIdJSONRequestBody PutApiPeersPeerIdJSONBody
|
||||||
|
|
||||||
// PutApiPeersIdJSONRequestBody defines body for PutApiPeersId for application/json ContentType.
|
|
||||||
type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody
|
|
||||||
|
|
||||||
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
|
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
|
||||||
type PostApiPoliciesJSONRequestBody = PostApiPoliciesJSONBody
|
type PostApiPoliciesJSONRequestBody = PostApiPoliciesJSONBody
|
||||||
|
|
||||||
// PutApiPoliciesIdJSONRequestBody defines body for PutApiPoliciesId for application/json ContentType.
|
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
|
||||||
type PutApiPoliciesIdJSONRequestBody = PutApiPoliciesIdJSONBody
|
type PutApiPoliciesPolicyIdJSONRequestBody = PutApiPoliciesPolicyIdJSONBody
|
||||||
|
|
||||||
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
|
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
|
||||||
type PostApiRoutesJSONRequestBody = RouteRequest
|
type PostApiRoutesJSONRequestBody = RouteRequest
|
||||||
|
|
||||||
// PatchApiRoutesIdJSONRequestBody defines body for PatchApiRoutesId for application/json ContentType.
|
// PutApiRoutesRouteIdJSONRequestBody defines body for PutApiRoutesRouteId for application/json ContentType.
|
||||||
type PatchApiRoutesIdJSONRequestBody = PatchApiRoutesIdJSONBody
|
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
|
||||||
|
|
||||||
// PutApiRoutesIdJSONRequestBody defines body for PutApiRoutesId for application/json ContentType.
|
|
||||||
type PutApiRoutesIdJSONRequestBody = RouteRequest
|
|
||||||
|
|
||||||
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
|
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
|
||||||
type PostApiRulesJSONRequestBody PostApiRulesJSONBody
|
type PostApiRulesJSONRequestBody PostApiRulesJSONBody
|
||||||
|
|
||||||
// PutApiRulesIdJSONRequestBody defines body for PutApiRulesId for application/json ContentType.
|
// PutApiRulesRuleIdJSONRequestBody defines body for PutApiRulesRuleId for application/json ContentType.
|
||||||
type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody
|
type PutApiRulesRuleIdJSONRequestBody PutApiRulesRuleIdJSONBody
|
||||||
|
|
||||||
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
|
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
|
||||||
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
|
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
|
||||||
|
|
||||||
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
|
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
|
||||||
type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest
|
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
|
||||||
|
|
||||||
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
|
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
|
||||||
type PostApiUsersJSONRequestBody = UserCreateRequest
|
type PostApiUsersJSONRequestBody = UserCreateRequest
|
||||||
|
|
||||||
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
|
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
|
||||||
type PutApiUsersIdJSONRequestBody = UserRequest
|
type PutApiUsersUserIdJSONRequestBody = UserRequest
|
||||||
|
|
||||||
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
|
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
|
||||||
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest
|
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
groupID, ok := vars["id"]
|
groupID, ok := vars["groupId"]
|
||||||
if !ok {
|
if !ok {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||||
return
|
return
|
||||||
@@ -88,7 +88,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiGroupsIdJSONRequestBody
|
var req api.PutApiGroupsGroupIdJSONRequestBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -121,110 +121,6 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchGroup handles patch updates to a group identified by a given ID
|
|
||||||
func (h *GroupsHandler) PatchGroup(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
|
||||||
groupID := vars["id"]
|
|
||||||
if len(groupID) == 0 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := account.Groups[groupID]
|
|
||||||
if !ok {
|
|
||||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.PatchApiGroupsIdJSONRequestBody
|
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req) == 0 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var operations []server.GroupUpdateOperation
|
|
||||||
|
|
||||||
for _, patch := range req {
|
|
||||||
switch patch.Path {
|
|
||||||
case api.GroupPatchOperationPathName:
|
|
||||||
if patch.Op != api.GroupPatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"name field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
operations = append(operations, server.GroupUpdateOperation{
|
|
||||||
Type: server.UpdateGroupName,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.GroupPatchOperationPathPeers:
|
|
||||||
switch patch.Op {
|
|
||||||
case api.GroupPatchOperationOpReplace:
|
|
||||||
peerKeys := peerIPsToKeys(account, &patch.Value)
|
|
||||||
operations = append(operations, server.GroupUpdateOperation{
|
|
||||||
Type: server.UpdateGroupPeers,
|
|
||||||
Values: peerKeys,
|
|
||||||
})
|
|
||||||
case api.GroupPatchOperationOpRemove:
|
|
||||||
peerKeys := peerIPsToKeys(account, &patch.Value)
|
|
||||||
operations = append(operations, server.GroupUpdateOperation{
|
|
||||||
Type: server.RemovePeersFromGroup,
|
|
||||||
Values: peerKeys,
|
|
||||||
})
|
|
||||||
case api.GroupPatchOperationOpAdd:
|
|
||||||
operations = append(operations, server.GroupUpdateOperation{
|
|
||||||
Type: server.InsertPeersToGroup,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"invalid operation, \"%v\", for PeersHandler field", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(w, toGroupResponse(account, group))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateGroup handles group creation request
|
// CreateGroup handles group creation request
|
||||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
@@ -277,7 +173,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
aID := account.Id
|
aID := account.Id
|
||||||
|
|
||||||
groupID := mux.Vars(r)["id"]
|
groupID := mux.Vars(r)["groupId"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
return
|
return
|
||||||
@@ -314,7 +210,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
groupID := mux.Vars(r)["id"]
|
groupID := mux.Vars(r)["groupId"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
return
|
return
|
||||||
@@ -335,29 +231,6 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func peerIPsToKeys(account *server.Account, peerIPs *[]string) []string {
|
|
||||||
var mappedPeerKeys []string
|
|
||||||
if peerIPs == nil {
|
|
||||||
return mappedPeerKeys
|
|
||||||
}
|
|
||||||
|
|
||||||
peersChecked := make(map[string]struct{})
|
|
||||||
|
|
||||||
for _, requestPeersIP := range *peerIPs {
|
|
||||||
_, ok := peersChecked[requestPeersIP]
|
|
||||||
if ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peersChecked[requestPeersIP] = struct{}{}
|
|
||||||
for _, accountPeer := range account.Peers {
|
|
||||||
if accountPeer.IP.String() == requestPeersIP {
|
|
||||||
mappedPeerKeys = append(mappedPeerKeys, accountPeer.Key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mappedPeerKeys
|
|
||||||
}
|
|
||||||
|
|
||||||
func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
|
func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
|
||||||
cache := make(map[string]api.PeerMinimum)
|
cache := make(map[string]api.PeerMinimum)
|
||||||
gr := api.Group{
|
gr := api.Group{
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/groups/{id}", p.GetGroup).Methods("GET")
|
router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@@ -230,53 +230,6 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Write Group PATCH Name OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/groups/id-existed",
|
|
||||||
requestBody: bytes.NewBuffer(
|
|
||||||
[]byte(`[{"op":"replace","path":"name","value":["Default POSTed Group"]}]`)),
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
expectedGroup: &api.Group{
|
|
||||||
Id: "id-existed",
|
|
||||||
Name: "Default POSTed Group",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Write Group PATCH Invalid Name OP",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/groups/id-existed",
|
|
||||||
requestBody: bytes.NewBuffer(
|
|
||||||
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
|
||||||
expectedBody: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Write Group PATCH Invalid Name",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/groups/id-existed",
|
|
||||||
requestBody: bytes.NewBuffer(
|
|
||||||
[]byte(`[{"op":"replace","path":"name","value":[]}]`)),
|
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
|
||||||
expectedBody: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Write Group PATCH PeersHandler OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/groups/id-existed",
|
|
||||||
requestBody: bytes.NewBuffer(
|
|
||||||
[]byte(`[{"op":"replace","path":"peers","value":["100.100.100.100","200.200.200.200"]}]`)),
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
expectedBody: true,
|
|
||||||
expectedGroup: &api.Group{
|
|
||||||
Id: "id-existed",
|
|
||||||
PeersCount: 2,
|
|
||||||
Peers: []api.PeerMinimum{
|
|
||||||
{Id: "peer-A-ID"},
|
|
||||||
{Id: "peer-B-ID"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
@@ -289,8 +242,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST")
|
router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST")
|
||||||
router.HandleFunc("/api/groups/{id}", p.UpdateGroup).Methods("PUT")
|
router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT")
|
||||||
router.HandleFunc("/api/groups/{id}", p.PatchGroup).Methods("PATCH")
|
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
|||||||
acMiddleware := middleware.NewAccessControl(
|
acMiddleware := middleware.NewAccessControl(
|
||||||
authCfg.Audience,
|
authCfg.Audience,
|
||||||
authCfg.UserIDClaim,
|
authCfg.UserIDClaim,
|
||||||
accountManager.IsUserAdmin)
|
accountManager.GetUser)
|
||||||
|
|
||||||
rootRouter := mux.NewRouter()
|
rootRouter := mux.NewRouter()
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
@@ -96,22 +96,22 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
|||||||
|
|
||||||
func (apiHandler *apiHandler) addAccountsEndpoint() {
|
func (apiHandler *apiHandler) addAccountsEndpoint() {
|
||||||
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/accounts/{id}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPeersEndpoint() {
|
func (apiHandler *apiHandler) addPeersEndpoint() {
|
||||||
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
|
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addUsersEndpoint() {
|
func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||||
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,56 +127,53 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
|
|||||||
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addRulesEndpoint() {
|
func (apiHandler *apiHandler) addRulesEndpoint() {
|
||||||
rulesHandler := NewRulesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
rulesHandler := NewRulesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
|
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
||||||
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
|
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addGroupsEndpoint() {
|
func (apiHandler *apiHandler) addGroupsEndpoint() {
|
||||||
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.PatchGroup).Methods("PATCH", "OPTIONS")
|
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addRoutesEndpoint() {
|
func (apiHandler *apiHandler) addRoutesEndpoint() {
|
||||||
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.PatchRoute).Methods("PATCH", "OPTIONS")
|
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
|
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
|
||||||
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroup).Methods("PATCH", "OPTIONS")
|
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addDNSSettingEndpoint() {
|
func (apiHandler *apiHandler) addDNSSettingEndpoint() {
|
||||||
|
|||||||
@@ -6,28 +6,30 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IsUserAdminFunc func(userID string) (bool, error)
|
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||||
|
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
|
|
||||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||||
type AccessControl struct {
|
type AccessControl struct {
|
||||||
isUserAdmin IsUserAdminFunc
|
|
||||||
claimsExtract jwtclaims.ClaimsExtractor
|
claimsExtract jwtclaims.ClaimsExtractor
|
||||||
|
getUser GetUser
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccessControl instance constructor
|
// NewAccessControl instance constructor
|
||||||
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||||
return &AccessControl{
|
return &AccessControl{
|
||||||
isUserAdmin: isUserAdmin,
|
|
||||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(audience),
|
jwtclaims.WithAudience(audience),
|
||||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||||
),
|
),
|
||||||
|
getUser: getUser,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := a.claimsExtract.FromRequestContext(r)
|
claims := a.claimsExtract.FromRequestContext(r)
|
||||||
|
|
||||||
ok, err := a.isUserAdmin(claims.UserId)
|
user, err := a.getUser(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !ok {
|
|
||||||
|
if user.IsBlocked() {
|
||||||
|
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.IsAdmin() {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||||
|
|
||||||
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Regex failed")
|
log.Debugf("regex failed")
|
||||||
util.WriteError(status.Errorf(status.Internal, ""), w)
|
util.WriteError(status.Errorf(status.Internal, ""), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
log.Debugf("Valid Path")
|
log.Debugf("valid Path")
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,13 +99,13 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiDnsNameserversIdJSONRequestBody
|
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -140,88 +140,6 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
util.WriteJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchNameserverGroup handles patch updates to a nameserver group identified by a given ID
|
|
||||||
func (h *NameserversHandler) PatchNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
|
||||||
if len(nsGroupID) == 0 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.PatchApiDnsNameserversIdJSONRequestBody
|
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var operations []server.NameServerGroupUpdateOperation
|
|
||||||
for _, patch := range req {
|
|
||||||
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"nameserver groups only accepts replace operations, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
switch patch.Path {
|
|
||||||
case api.NameserverGroupPatchOperationPathName:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupName,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathDescription:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupDescription,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathPrimary:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupPrimary,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathDomains:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupDomains,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathNameservers:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupNameServers,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathGroups:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupGroups,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.NameserverGroupPatchOperationPathEnabled:
|
|
||||||
operations = append(operations, server.NameServerGroupUpdateOperation{
|
|
||||||
Type: server.UpdateNameServerGroupEnabled,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, user.Id, operations)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
|
||||||
|
|
||||||
util.WriteJSONObject(w, &resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteNameserverGroup handles nameserver group deletion request
|
// DeleteNameserverGroup handles nameserver group deletion request
|
||||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
@@ -231,7 +149,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
@@ -256,7 +174,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -227,31 +227,6 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "PATCH OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
|
|
||||||
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
expectedBody: true,
|
|
||||||
expectedNSGroup: &api.NameserverGroup{
|
|
||||||
Id: existingNSGroupID,
|
|
||||||
Name: baseExistingNSGroup.Name,
|
|
||||||
Description: "NewDesc",
|
|
||||||
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
|
|
||||||
Groups: baseExistingNSGroup.Groups,
|
|
||||||
Enabled: baseExistingNSGroup.Enabled,
|
|
||||||
Primary: baseExistingNSGroup.Primary,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PATCH Invalid Nameserver Group OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/dns/nameservers/" + notFoundRouteID,
|
|
||||||
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
|
|
||||||
expectedStatus: http.StatusNotFound,
|
|
||||||
expectedBody: false,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p := initNameserversTestData()
|
p := initNameserversTestData()
|
||||||
@@ -262,11 +237,10 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/dns/nameservers/{id}", p.GetNameserverGroup).Methods("GET")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET")
|
||||||
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST")
|
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST")
|
||||||
router.HandleFunc("/api/dns/nameservers/{id}", p.DeleteNameserverGroup).Methods("DELETE")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE")
|
||||||
router.HandleFunc("/api/dns/nameservers/{id}", p.UpdateNameserverGroup).Methods("PUT")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT")
|
||||||
router.HandleFunc("/api/dns/nameservers/{id}", p.PatchNameserverGroup).Methods("PATCH")
|
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ var testAccount = &server.Account{
|
|||||||
func initPATTestData() *PATHandler {
|
func initPATTestData() *PATHandler {
|
||||||
return &PATHandler{
|
return &PATHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
|
|||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testAccount, testAccount.Users[existingUserID], nil
|
return testAccount, testAccount.Users[existingUserID], nil
|
||||||
},
|
},
|
||||||
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
|
|||||||
}
|
}
|
||||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||||
},
|
},
|
||||||
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PutApiPeersIdJSONBody{}
|
req := &api.PutApiPeersPeerIdJSONBody{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -78,7 +78,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
peerID := vars["id"]
|
peerID := vars["peerId"]
|
||||||
if len(peerID) == 0 {
|
if len(peerID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -146,8 +146,8 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||||
router.HandleFunc("/api/peers/{id}", p.HandlePeer).Methods("GET")
|
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
|
||||||
router.HandleFunc("/api/peers/{id}", p.HandlePeer).Methods("PUT")
|
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["id"]
|
policyID := vars["policyId"]
|
||||||
if len(policyID) == 0 {
|
if len(policyID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
return
|
return
|
||||||
@@ -78,7 +78,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiPoliciesIdJSONRequestBody
|
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -214,7 +214,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
aID := account.Id
|
aID := account.Id
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["id"]
|
policyID := vars["policyId"]
|
||||||
if len(policyID) == 0 {
|
if len(policyID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
return
|
return
|
||||||
@@ -240,7 +240,7 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["id"]
|
policyID := vars["policyId"]
|
||||||
if len(policyID) == 0 {
|
if len(policyID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
return
|
return
|
||||||
@@ -103,7 +103,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
routeID := vars["id"]
|
routeID := vars["routeId"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
@@ -115,7 +115,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiRoutesIdJSONRequestBody
|
var req api.PutApiRoutesRouteIdJSONRequestBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -159,147 +159,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchRoute handles patch updates to a route identified by a given ID
|
|
||||||
func (h *RoutesHandler) PatchRoute(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
|
||||||
routeID := vars["id"]
|
|
||||||
if len(routeID) == 0 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req api.PatchApiRoutesIdJSONRequestBody
|
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req) == 0 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var operations []server.RouteUpdateOperation
|
|
||||||
|
|
||||||
for _, patch := range req {
|
|
||||||
switch patch.Path {
|
|
||||||
case api.RoutePatchOperationPathNetwork:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"network field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteNetwork,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathDescription:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"description field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteDescription,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathNetworkId:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"network Identifier field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteNetworkIdentifier,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathPeer:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"peer field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(patch.Value) > 1 {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"value field only accepts 1 value, got %d", len(patch.Value)), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRoutePeer,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathMetric:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"metric field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteMetric,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathMasquerade:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"masquerade field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteMasquerade,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathEnabled:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"enabled field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteEnabled,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
case api.RoutePatchOperationPathGroups:
|
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
|
||||||
"groups field only accepts replace operation, got %s", patch.Op), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
|
||||||
Type: server.UpdateRouteGroups,
|
|
||||||
Values: patch.Value,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
root, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := toRouteResponse(root)
|
|
||||||
|
|
||||||
util.WriteJSONObject(w, &resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRoute handles route deletion request
|
// DeleteRoute handles route deletion request
|
||||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
@@ -309,7 +168,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routeID := mux.Vars(r)["id"]
|
routeID := mux.Vars(r)["routeId"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
@@ -333,7 +192,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routeID := mux.Vars(r)["id"]
|
routeID := mux.Vars(r)["routeId"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -288,61 +288,6 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "PATCH Description OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
|
||||||
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
expectedBody: true,
|
|
||||||
expectedRoute: &api.Route{
|
|
||||||
Id: existingRouteID,
|
|
||||||
Description: "NewDesc",
|
|
||||||
NetworkId: "awesomeNet",
|
|
||||||
Network: baseExistingRoute.Network.String(),
|
|
||||||
NetworkType: route.IPv4NetworkString,
|
|
||||||
Masquerade: baseExistingRoute.Masquerade,
|
|
||||||
Enabled: baseExistingRoute.Enabled,
|
|
||||||
Metric: baseExistingRoute.Metric,
|
|
||||||
Groups: baseExistingRoute.Groups,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PATCH Peer OK",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", existingPeerID)),
|
|
||||||
expectedStatus: http.StatusOK,
|
|
||||||
expectedBody: true,
|
|
||||||
expectedRoute: &api.Route{
|
|
||||||
Id: existingRouteID,
|
|
||||||
Description: "NewDesc",
|
|
||||||
NetworkId: "awesomeNet",
|
|
||||||
Network: baseExistingRoute.Network.String(),
|
|
||||||
NetworkType: route.IPv4NetworkString,
|
|
||||||
Peer: existingPeerID,
|
|
||||||
Masquerade: baseExistingRoute.Masquerade,
|
|
||||||
Enabled: baseExistingRoute.Enabled,
|
|
||||||
Metric: baseExistingRoute.Metric,
|
|
||||||
Groups: baseExistingRoute.Groups,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PATCH Not Found Peer",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
|
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
|
||||||
expectedBody: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PATCH Not Found Route",
|
|
||||||
requestType: http.MethodPatch,
|
|
||||||
requestPath: "/api/routes/" + notFoundRouteID,
|
|
||||||
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"network\",\"value\":[\"192.168.0.0/34\"]}]"),
|
|
||||||
expectedStatus: http.StatusNotFound,
|
|
||||||
expectedBody: false,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p := initRoutesTestData()
|
p := initRoutesTestData()
|
||||||
@@ -353,11 +298,10 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/routes/{id}", p.GetRoute).Methods("GET")
|
router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET")
|
||||||
router.HandleFunc("/api/routes/{id}", p.DeleteRoute).Methods("DELETE")
|
router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE")
|
||||||
router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST")
|
router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST")
|
||||||
router.HandleFunc("/api/routes/{id}", p.UpdateRoute).Methods("PUT")
|
router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT")
|
||||||
router.HandleFunc("/api/routes/{id}", p.PatchRoute).Methods("PATCH")
|
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
ruleID := vars["id"]
|
ruleID := vars["ruleId"]
|
||||||
if len(ruleID) == 0 {
|
if len(ruleID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
@@ -77,7 +77,7 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiRulesIdJSONRequestBody
|
var req api.PutApiRulesRuleIdJSONRequestBody
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -210,7 +210,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
aID := account.Id
|
aID := account.Id
|
||||||
|
|
||||||
rID := mux.Vars(r)["id"]
|
rID := mux.Vars(r)["ruleId"]
|
||||||
if len(rID) == 0 {
|
if len(rID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
@@ -236,7 +236,7 @@ func (h *RulesHandler) GetRule(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
ruleID := mux.Vars(r)["id"]
|
ruleID := mux.Vars(r)["ruleId"]
|
||||||
if len(ruleID) == 0 {
|
if len(ruleID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ func TestRulesGetRule(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/rules/{id}", p.GetRule).Methods("GET")
|
router.HandleFunc("/api/rules/{ruleId}", p.GetRule).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@@ -235,7 +235,7 @@ func TestRulesWriteRule(t *testing.T) {
|
|||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/rules", p.CreateRule).Methods("POST")
|
router.HandleFunc("/api/rules", p.CreateRule).Methods("POST")
|
||||||
router.HandleFunc("/api/rules/{id}", p.UpdateRule).Methods("PUT")
|
router.HandleFunc("/api/rules/{ruleId}", p.UpdateRule).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["id"]
|
keyID := vars["keyId"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||||
return
|
return
|
||||||
@@ -109,13 +109,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["id"]
|
keyID := vars["keyId"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
|||||||
@@ -174,8 +174,8 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -48,19 +48,24 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["id"]
|
userID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PutApiUsersIdJSONRequestBody{}
|
req := &api.PutApiUsersUserIdJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.AutoGroups == nil {
|
||||||
|
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
userRole := server.StrRoleToUserRole(req.Role)
|
userRole := server.StrRoleToUserRole(req.Role)
|
||||||
if userRole == server.UserRoleUnknown {
|
if userRole == server.UserRoleUnknown {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||||
@@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
Id: userID,
|
Id: userID,
|
||||||
Role: userRole,
|
Role: userRole,
|
||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
|
Blocked: req.IsBlocked,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -94,7 +101,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
targetUserID := vars["id"]
|
targetUserID := vars["userId"]
|
||||||
if len(targetUserID) == 0 {
|
if len(targetUserID) == 0 {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
@@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
|||||||
case "invited":
|
case "invited":
|
||||||
userStatus = api.UserStatusInvited
|
userStatus = api.UserStatusInvited
|
||||||
default:
|
default:
|
||||||
userStatus = api.UserStatusDisabled
|
userStatus = api.UserStatusBlocked
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsBlocked {
|
||||||
|
userStatus = api.UserStatusBlocked
|
||||||
}
|
}
|
||||||
|
|
||||||
isCurrent := user.ID == currenUserID
|
isCurrent := user.ID == currenUserID
|
||||||
@@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
|||||||
Status: userStatus,
|
Status: userStatus,
|
||||||
IsCurrent: &isCurrent,
|
IsCurrent: &isCurrent,
|
||||||
IsServiceUser: &user.IsServiceUser,
|
IsServiceUser: &user.IsServiceUser,
|
||||||
|
IsBlocked: user.IsBlocked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
|
|||||||
Id: existingUserID,
|
Id: existingUserID,
|
||||||
Role: "admin",
|
Role: "admin",
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
regularUserID: {
|
regularUserID: {
|
||||||
Id: regularUserID,
|
Id: regularUserID,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
serviceUserID: {
|
serviceUserID: {
|
||||||
Id: serviceUserID,
|
Id: serviceUserID,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
IsServiceUser: true,
|
IsServiceUser: true,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
},
|
},
|
||||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
if targetUserID == notFoundUserID {
|
if targetUserID == notFoundUserID {
|
||||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||||
}
|
}
|
||||||
@@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||||
|
if update.Id == notFoundUserID {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID != existingUserID {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := update.Copy().ToUserInfo(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateUser(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatusCode int
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestBody io.Reader
|
||||||
|
expectedUserID string
|
||||||
|
expectedRole string
|
||||||
|
expectedStatus string
|
||||||
|
expectedBlocked bool
|
||||||
|
expectedIsServiceUser bool
|
||||||
|
expectedGroups []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Update_Block_User",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: true,
|
||||||
|
expectedRole: "user",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_1"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update_Change_Role_To_Admin",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_1"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update_Groups",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_2", "group_3"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_Because_AutoGroups_Is_Absent",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusBadRequest,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_2", "group_3"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userHandler := initUsersTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatusCode {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedStatusCode == 200 {
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := &api.User{}
|
||||||
|
err = json.Unmarshal(content, &respBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("response content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedUserID, respBody.Id)
|
||||||
|
assert.Equal(t, tc.expectedRole, respBody.Role)
|
||||||
|
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
|
||||||
|
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
|
||||||
|
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
|
||||||
|
|
||||||
|
for _, expectedGroup := range tc.expectedGroups {
|
||||||
|
exists := false
|
||||||
|
for _, actualGroup := range respBody.AutoGroups {
|
||||||
|
if expectedGroup == actualGroup {
|
||||||
|
exists = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateUser(t *testing.T) {
|
func TestCreateUser(t *testing.T) {
|
||||||
name := "name"
|
name := "name"
|
||||||
email := "email"
|
email := "email"
|
||||||
@@ -219,21 +354,21 @@ func TestDeleteUser(t *testing.T) {
|
|||||||
name: "Delete Regular User",
|
name: "Delete Regular User",
|
||||||
requestType: http.MethodDelete,
|
requestType: http.MethodDelete,
|
||||||
requestPath: "/api/users/" + regularUserID,
|
requestPath: "/api/users/" + regularUserID,
|
||||||
requestVars: map[string]string{"id": regularUserID},
|
requestVars: map[string]string{"userId": regularUserID},
|
||||||
expectedStatus: http.StatusForbidden,
|
expectedStatus: http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Delete Service User",
|
name: "Delete Service User",
|
||||||
requestType: http.MethodDelete,
|
requestType: http.MethodDelete,
|
||||||
requestPath: "/api/users/" + serviceUserID,
|
requestPath: "/api/users/" + serviceUserID,
|
||||||
requestVars: map[string]string{"id": serviceUserID},
|
requestVars: map[string]string{"userId": serviceUserID},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Delete Not Existing User",
|
name: "Delete Not Existing User",
|
||||||
requestType: http.MethodDelete,
|
requestType: http.MethodDelete,
|
||||||
requestPath: "/api/users/" + notFoundUserID,
|
requestPath: "/api/users/" + notFoundUserID,
|
||||||
requestVars: map[string]string{"id": notFoundUserID},
|
requestVars: map[string]string{"userId": notFoundUserID},
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,10 +32,10 @@ type Auth0Manager struct {
|
|||||||
// Auth0ClientConfig auth0 manager client configurations
|
// Auth0ClientConfig auth0 manager client configurations
|
||||||
type Auth0ClientConfig struct {
|
type Auth0ClientConfig struct {
|
||||||
Audience string
|
Audience string
|
||||||
AuthIssuer string
|
AuthIssuer string `json:"-"`
|
||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
GrantType string
|
GrantType string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// auth0JWTRequest payload struct to request a JWT Token
|
// auth0JWTRequest payload struct to request a JWT Token
|
||||||
@@ -110,7 +110,8 @@ type auth0Profile struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
func NewAuth0Manager(oidcConfig OIDCConfig, config Auth0ClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||||
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
@@ -121,17 +122,19 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
config.AuthIssuer = oidcConfig.TokenEndpoint
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.Audience == "" || config.AuthIssuer == "" {
|
if config.ClientID == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration is not complete")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, clientID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GrantType != "client_credentials" {
|
if config.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration failed. Grant Type should be client_credentials")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(strings.ToLower(config.AuthIssuer), "https://") {
|
if config.Audience == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration failed. AuthIssuer should contain https://")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials := &Auth0Credentials{
|
credentials := &Auth0Credentials{
|
||||||
|
|||||||
@@ -459,26 +459,9 @@ func TestNewAuth0Manager(t *testing.T) {
|
|||||||
testCase3Config := defaultTestConfig
|
testCase3Config := defaultTestConfig
|
||||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||||
|
|
||||||
testCase3 := test{
|
for _, testCase := range []test{testCase1, testCase2} {
|
||||||
name: "Wrong Auth Issuer Format",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when wrong auth issuer format",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase4Config := defaultTestConfig
|
|
||||||
testCase4Config.GrantType = "spa"
|
|
||||||
|
|
||||||
testCase4 := test{
|
|
||||||
name: "Wrong Grant Type",
|
|
||||||
inputConfig: testCase4Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when wrong grant type",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
_, err := NewAuth0Manager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
671
management/server/idp/azure.go
Normal file
671
management/server/idp/azure.go
Normal file
@@ -0,0 +1,671 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// azure extension properties template
|
||||||
|
wtAccountIDTpl = "extension_%s_wt_account_id"
|
||||||
|
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
|
||||||
|
|
||||||
|
profileFields = "id,displayName,mail,userPrincipalName"
|
||||||
|
extensionFields = "id,name,targetObjects"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AzureManager azure manager client instance.
|
||||||
|
type AzureManager struct {
|
||||||
|
ClientID string
|
||||||
|
ObjectID string
|
||||||
|
GraphAPIEndpoint string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
credentials ManagerCredentials
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// AzureClientConfig azure manager client configurations.
|
||||||
|
type AzureClientConfig struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
ObjectID string
|
||||||
|
|
||||||
|
GraphAPIEndpoint string `json:"-"`
|
||||||
|
TokenEndpoint string `json:"-"`
|
||||||
|
GrantType string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AzureCredentials azure authentication information.
|
||||||
|
type AzureCredentials struct {
|
||||||
|
clientConfig AzureClientConfig
|
||||||
|
helper ManagerHelper
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
jwtToken JWTToken
|
||||||
|
mux sync.Mutex
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// azureProfile represents an azure user profile.
|
||||||
|
type azureProfile map[string]any
|
||||||
|
|
||||||
|
// passwordProfile represent authentication method for,
|
||||||
|
// newly created user profile.
|
||||||
|
type passwordProfile struct {
|
||||||
|
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// azureExtension represent custom attribute,
|
||||||
|
// that can be added to user objects in Azure Active Directory (AD).
|
||||||
|
type azureExtension struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
DataType string `json:"dataType"`
|
||||||
|
TargetObjects []string `json:"targetObjects"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAzureManager creates a new instance of the AzureManager.
|
||||||
|
func NewAzureManager(oidcConfig OIDCConfig, config AzureClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.GraphAPIEndpoint = "https://graph.microsoft.com"
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ClientSecret == "" {
|
||||||
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ObjectID == "" {
|
||||||
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := &AzureCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &AzureManager{
|
||||||
|
ObjectID: config.ObjectID,
|
||||||
|
ClientID: config.ClientID,
|
||||||
|
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||||
|
httpClient: httpClient,
|
||||||
|
credentials: credentials,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.configureAppMetadata()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||||
|
func (ac *AzureCredentials) jwtStillValid() bool {
|
||||||
|
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestJWTToken performs request to get jwt token.
|
||||||
|
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", ac.clientConfig.ClientID)
|
||||||
|
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||||
|
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||||
|
data.Set("scope", "https://graph.microsoft.com/.default")
|
||||||
|
|
||||||
|
payload := strings.NewReader(data.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
log.Debug("requesting new jwt token for azure idp manager")
|
||||||
|
|
||||||
|
resp, err := ac.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if ac.appMetrics != nil {
|
||||||
|
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||||
|
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
body, err := io.ReadAll(rawBody)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||||
|
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exp maps into exp from jwt token
|
||||||
|
var IssuedAt struct{ Exp int64 }
|
||||||
|
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||||
|
|
||||||
|
return jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate retrieves access token to use the azure Management API.
|
||||||
|
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
ac.mux.Lock()
|
||||||
|
defer ac.mux.Unlock()
|
||||||
|
|
||||||
|
if ac.appMetrics != nil {
|
||||||
|
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reuse the token without requesting a new one if it is not expired,
|
||||||
|
// and if expiry time is sufficient time available to make a request.
|
||||||
|
if ac.jwtStillValid() {
|
||||||
|
return ac.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ac.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return ac.jwtToken, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ac.jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ac.jwtToken = jwtToken
|
||||||
|
|
||||||
|
return ac.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in azure AD Idp.
|
||||||
|
func (am *AzureManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||||
|
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := am.post("users", payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile azureProfile
|
||||||
|
err = am.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
profile[wtAccountIDField] = accountID
|
||||||
|
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
profile[wtPendingInviteField] = true
|
||||||
|
|
||||||
|
return profile.userData(am.ClientID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from keycloak via ID.
|
||||||
|
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("$select", selectFields)
|
||||||
|
|
||||||
|
body, err := am.get("users/"+userID, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile azureProfile
|
||||||
|
err = am.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return profile.userData(am.ClientID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("$select", selectFields)
|
||||||
|
|
||||||
|
body, err := am.get("users/"+email, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile azureProfile
|
||||||
|
err = am.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
users = append(users, profile.userData(am.ClientID))
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given profile.
|
||||||
|
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("$select", selectFields)
|
||||||
|
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
|
||||||
|
|
||||||
|
body, err := am.get("users", q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Value []azureProfile }
|
||||||
|
err = am.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, profile := range profiles.Value {
|
||||||
|
users = append(users, profile.userData(am.ClientID))
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// It returns a list of users indexed by accountID.
|
||||||
|
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("$select", selectFields)
|
||||||
|
|
||||||
|
body, err := am.get("users", q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Value []azureProfile }
|
||||||
|
err = am.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
for _, profile := range profiles.Value {
|
||||||
|
userData := profile.userData(am.ClientID)
|
||||||
|
|
||||||
|
accountID := userData.AppMetadata.WTAccountID
|
||||||
|
if accountID != "" {
|
||||||
|
if _, ok := indexedUsers[accountID]; !ok {
|
||||||
|
indexedUsers[accountID] = make([]*UserData, 0)
|
||||||
|
}
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||||
|
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
|
||||||
|
data, err := am.helper.Marshal(map[string]any{
|
||||||
|
wtAccountIDField: appMetadata.WTAccountID,
|
||||||
|
wtPendingInviteField: appMetadata.WTPendingInvite,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := strings.NewReader(string(data))
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
|
||||||
|
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
log.Debugf("updating idp metadata for user %s", userID)
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("$select", extensionFields)
|
||||||
|
|
||||||
|
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
||||||
|
body, err := am.get(resource, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var extensions struct{ Value []azureExtension }
|
||||||
|
err = am.helper.Unmarshal(body, &extensions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return extensions.Value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
|
||||||
|
extension := azureExtension{
|
||||||
|
Name: name,
|
||||||
|
DataType: "string",
|
||||||
|
TargetObjects: []string{"User"},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := am.helper.Marshal(extension)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
||||||
|
body, err := am.post(resource, string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var userExtension azureExtension
|
||||||
|
err = am.helper.Unmarshal(body, &userExtension)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &userExtension, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureAppMetadata sets up app metadata extensions if they do not exists.
|
||||||
|
func (am *AzureManager) configureAppMetadata() error {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||||
|
|
||||||
|
extensions, err := am.getUserExtensions()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the wt_account_id extension does not already exist, create it.
|
||||||
|
if !hasExtension(extensions, wtAccountIDField) {
|
||||||
|
_, err = am.createUserExtension(wtAccountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the wt_pending_invite extension does not already exist, create it.
|
||||||
|
if !hasExtension(extensions, wtPendingInviteField) {
|
||||||
|
_, err = am.createUserExtension(wtPendingInvite)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get perform Get requests.
|
||||||
|
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// post perform Post requests.
|
||||||
|
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
|
||||||
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
|
||||||
|
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := am.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
if am.appMetrics != nil {
|
||||||
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// userData construct user data from keycloak profile.
|
||||||
|
func (ap azureProfile) userData(clientID string) *UserData {
|
||||||
|
id, ok := ap["id"].(string)
|
||||||
|
if !ok {
|
||||||
|
id = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
email, ok := ap["userPrincipalName"].(string)
|
||||||
|
if !ok {
|
||||||
|
email = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
name, ok := ap["displayName"].(string)
|
||||||
|
if !ok {
|
||||||
|
name = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
accountIDField := extensionName(wtAccountIDTpl, clientID)
|
||||||
|
accountID, ok := ap[accountIDField].(string)
|
||||||
|
if !ok {
|
||||||
|
accountID = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
||||||
|
pendingInvite, ok := ap[pendingInviteField].(bool)
|
||||||
|
if !ok {
|
||||||
|
pendingInvite = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
Email: email,
|
||||||
|
Name: name,
|
||||||
|
ID: id,
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &pendingInvite,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
|
||||||
|
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
|
||||||
|
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
||||||
|
|
||||||
|
req := &azureProfile{
|
||||||
|
"accountEnabled": true,
|
||||||
|
"displayName": name,
|
||||||
|
"mailNickName": strings.Join(strings.Split(name, " "), ""),
|
||||||
|
"userPrincipalName": email,
|
||||||
|
"passwordProfile": passwordProfile{
|
||||||
|
ForceChangePasswordNextSignIn: true,
|
||||||
|
Password: GeneratePassword(8, 1, 1, 1),
|
||||||
|
},
|
||||||
|
wtAccountIDField: accountID,
|
||||||
|
wtPendingInviteField: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(str), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extensionName(extensionTpl, clientID string) string {
|
||||||
|
clientID = strings.ReplaceAll(clientID, "-", "")
|
||||||
|
return fmt.Sprintf(extensionTpl, clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasExtension checks whether a given extension by name,
|
||||||
|
// exists in an list of extensions.
|
||||||
|
func hasExtension(extensions []azureExtension, name string) bool {
|
||||||
|
for _, ext := range extensions {
|
||||||
|
if ext.Name == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
329
management/server/idp/azure_test.go
Normal file
329
management/server/idp/azure_test.go
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAzureCredentials struct {
|
||||||
|
jwtToken JWTToken
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
return mc.jwtToken, mc.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureJwtStillValid(t *testing.T) {
|
||||||
|
type jwtStillValidTest struct {
|
||||||
|
name string
|
||||||
|
inputTime time.Time
|
||||||
|
expectedResult bool
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||||
|
name: "JWT still valid",
|
||||||
|
inputTime: time.Now().Add(10 * time.Second),
|
||||||
|
expectedResult: true,
|
||||||
|
message: "should be true",
|
||||||
|
}
|
||||||
|
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||||
|
name: "JWT is invalid",
|
||||||
|
inputTime: time.Now(),
|
||||||
|
expectedResult: false,
|
||||||
|
message: "should be false",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config := AzureClientConfig{}
|
||||||
|
|
||||||
|
creds := AzureCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureAuthenticate(t *testing.T) {
|
||||||
|
type authenticateTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputResBody string
|
||||||
|
inputExpireToken time.Time
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedCode int
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
authenticateTestCase1 := authenticateTest{
|
||||||
|
name: "Get Cached token",
|
||||||
|
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: nil,
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase2 := authenticateTest{
|
||||||
|
name: "Get Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase3 := authenticateTest{
|
||||||
|
name: "Get Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputResBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputResBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := AzureClientConfig{}
|
||||||
|
|
||||||
|
creds := AzureCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||||
|
|
||||||
|
_, err := creds.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureUpdateUserAppMetadata(t *testing.T) {
|
||||||
|
type updateUserAppMetadataTest struct {
|
||||||
|
name string
|
||||||
|
inputReqBody string
|
||||||
|
expectedReqBody string
|
||||||
|
appMetadata AppMetadata
|
||||||
|
statusCode int
|
||||||
|
helper ManagerHelper
|
||||||
|
managerCreds ManagerCredentials
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Authentication",
|
||||||
|
expectedReqBody: "",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockAzureCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
err: fmt.Errorf("error"),
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Status Code",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockAzureCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Response Parsing",
|
||||||
|
statusCode: 400,
|
||||||
|
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||||
|
managerCreds: &mockAzureCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||||
|
name: "Good request",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 204,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockAzureCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||||
|
name: "Update Pending Invite",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: AppMetadata{
|
||||||
|
WTAccountID: "ok",
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
},
|
||||||
|
statusCode: 204,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockAzureCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||||
|
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
reqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputReqBody,
|
||||||
|
code: testCase.statusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &AzureManager{
|
||||||
|
httpClient: &reqClient,
|
||||||
|
credentials: testCase.managerCreds,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureProfile(t *testing.T) {
|
||||||
|
type azureProfileTest struct {
|
||||||
|
name string
|
||||||
|
clientID string
|
||||||
|
invite bool
|
||||||
|
inputProfile azureProfile
|
||||||
|
expectedUserData UserData
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase1 := azureProfileTest{
|
||||||
|
name: "Good Request",
|
||||||
|
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||||
|
invite: false,
|
||||||
|
inputProfile: azureProfile{
|
||||||
|
"id": "test1",
|
||||||
|
"displayName": "John Doe",
|
||||||
|
"userPrincipalName": "test1@test.com",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
Email: "test1@test.com",
|
||||||
|
Name: "John Doe",
|
||||||
|
ID: "test1",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase2 := azureProfileTest{
|
||||||
|
name: "Missing User ID",
|
||||||
|
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||||
|
invite: true,
|
||||||
|
inputProfile: azureProfile{
|
||||||
|
"displayName": "John Doe",
|
||||||
|
"userPrincipalName": "test2@test.com",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
Email: "test2@test.com",
|
||||||
|
Name: "John Doe",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase3 := azureProfileTest{
|
||||||
|
name: "Missing User Name",
|
||||||
|
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||||
|
invite: false,
|
||||||
|
inputProfile: azureProfile{
|
||||||
|
"id": "test3",
|
||||||
|
"userPrincipalName": "test3@test.com",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||||
|
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test3",
|
||||||
|
Email: "test3@test.com",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase4 := azureProfileTest{
|
||||||
|
name: "Missing Extension Fields",
|
||||||
|
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||||
|
invite: false,
|
||||||
|
inputProfile: azureProfile{
|
||||||
|
"id": "test4",
|
||||||
|
"displayName": "John Doe",
|
||||||
|
"userPrincipalName": "test4@test.com",
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test4",
|
||||||
|
Name: "John Doe",
|
||||||
|
Email: "test4@test.com",
|
||||||
|
AppMetadata: AppMetadata{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||||
|
userData := testCase.inputProfile.userData(testCase.clientID)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,11 +19,21 @@ type Manager interface {
|
|||||||
GetUserByEmail(email string) ([]*UserData, error)
|
GetUserByEmail(email string) ([]*UserData, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OIDCConfig specifies configuration for OpenID Connect provider
|
||||||
|
// These configurations are automatically loaded from the OIDC endpoint
|
||||||
|
type OIDCConfig struct {
|
||||||
|
Issuer string
|
||||||
|
TokenEndpoint string
|
||||||
|
}
|
||||||
|
|
||||||
// Config an idp configuration struct to be loaded from management server's config file
|
// Config an idp configuration struct to be loaded from management server's config file
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ManagerType string
|
ManagerType string
|
||||||
|
OIDCConfig OIDCConfig `json:"-"`
|
||||||
Auth0ClientCredentials Auth0ClientConfig
|
Auth0ClientCredentials Auth0ClientConfig
|
||||||
|
AzureClientCredentials AzureClientConfig
|
||||||
KeycloakClientCredentials KeycloakClientConfig
|
KeycloakClientCredentials KeycloakClientConfig
|
||||||
|
ZitadelClientCredentials ZitadelClientConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||||
@@ -72,9 +82,13 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
case "none", "":
|
case "none", "":
|
||||||
return nil, nil
|
return nil, nil
|
||||||
case "auth0":
|
case "auth0":
|
||||||
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
return NewAuth0Manager(config.OIDCConfig, config.Auth0ClientCredentials, appMetrics)
|
||||||
|
case "azure":
|
||||||
|
return NewAzureManager(config.OIDCConfig, config.AzureClientCredentials, appMetrics)
|
||||||
case "keycloak":
|
case "keycloak":
|
||||||
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
|
return NewKeycloakManager(config.OIDCConfig, config.KeycloakClientCredentials, appMetrics)
|
||||||
|
case "zitadel":
|
||||||
|
return NewZitadelManager(config.OIDCConfig, config.ZitadelClientCredentials, appMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ type KeycloakClientConfig struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
AdminEndpoint string
|
AdminEndpoint string
|
||||||
TokenEndpoint string
|
TokenEndpoint string `json:"-"`
|
||||||
GrantType string
|
GrantType string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeycloakCredentials keycloak authentication information.
|
// KeycloakCredentials keycloak authentication information.
|
||||||
@@ -82,7 +82,8 @@ type keycloakProfile struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
func NewKeycloakManager(oidcConfig OIDCConfig, config KeycloakClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -92,13 +93,19 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
|||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
|
if config.ClientID == "" {
|
||||||
return nil, fmt.Errorf("keycloak idp configuration is not complete")
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GrantType != "client_credentials" {
|
if config.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AdminEndpoint == "" {
|
||||||
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials := &KeycloakCredentials{
|
credentials := &KeycloakCredentials{
|
||||||
|
|||||||
@@ -46,19 +46,19 @@ func TestNewKeycloakManager(t *testing.T) {
|
|||||||
assertErrFuncMessage: "should return error when field empty",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
}
|
}
|
||||||
|
|
||||||
testCase5Config := defaultTestConfig
|
testCase3Config := defaultTestConfig
|
||||||
testCase5Config.GrantType = "authorization_code"
|
testCase3Config.ClientSecret = ""
|
||||||
|
|
||||||
testCase5 := test{
|
testCase3 := test{
|
||||||
name: "Wrong GrantType",
|
name: "Missing ClientSecret Configuration",
|
||||||
inputConfig: testCase5Config,
|
inputConfig: testCase3Config,
|
||||||
assertErrFunc: require.Error,
|
assertErrFunc: require.Error,
|
||||||
assertErrFuncMessage: "should return error when wrong grant type",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase5} {
|
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
_, err := NewKeycloakManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
609
management/server/idp/zitadel.go
Normal file
609
management/server/idp/zitadel.go
Normal file
@@ -0,0 +1,609 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ZitadelManager zitadel manager client instance.
|
||||||
|
type ZitadelManager struct {
|
||||||
|
managementEndpoint string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
credentials ManagerCredentials
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZitadelClientConfig zitadel manager client configurations.
|
||||||
|
type ZitadelClientConfig struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
GrantType string `json:"-"`
|
||||||
|
TokenEndpoint string `json:"-"`
|
||||||
|
ManagementEndpoint string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZitadelCredentials zitadel authentication information.
|
||||||
|
type ZitadelCredentials struct {
|
||||||
|
clientConfig ZitadelClientConfig
|
||||||
|
helper ManagerHelper
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
jwtToken JWTToken
|
||||||
|
mux sync.Mutex
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelEmail specifies details of a user email.
|
||||||
|
type zitadelEmail struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
IsEmailVerified bool `json:"isEmailVerified"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelUserInfo specifies user information.
|
||||||
|
type zitadelUserInfo struct {
|
||||||
|
FirstName string `json:"firstName"`
|
||||||
|
LastName string `json:"lastName"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelUser specifies profile details for user account.
|
||||||
|
type zitadelUser struct {
|
||||||
|
UserName string `json:"userName,omitempty"`
|
||||||
|
Profile zitadelUserInfo `json:"profile"`
|
||||||
|
Email zitadelEmail `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type zitadelAttributes map[string][]map[string]any
|
||||||
|
|
||||||
|
// zitadelMetadata holds additional user data.
|
||||||
|
type zitadelMetadata struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelProfile represents an zitadel user profile response.
|
||||||
|
type zitadelProfile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
State string `json:"state"`
|
||||||
|
UserName string `json:"userName"`
|
||||||
|
PreferredLoginName string `json:"preferredLoginName"`
|
||||||
|
LoginNames []string `json:"loginNames"`
|
||||||
|
Human *zitadelUser `json:"human"`
|
||||||
|
Metadata []zitadelMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewZitadelManager creates a new instance of the ZitadelManager.
|
||||||
|
func NewZitadelManager(oidcConfig OIDCConfig, config ZitadelClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.ManagementEndpoint = fmt.Sprintf("%s/management/v1", oidcConfig.Issuer)
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, clientID is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ClientSecret == "" {
|
||||||
|
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := &ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ZitadelManager{
|
||||||
|
managementEndpoint: config.ManagementEndpoint,
|
||||||
|
httpClient: httpClient,
|
||||||
|
credentials: credentials,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from zitadel.
|
||||||
|
func (zc *ZitadelCredentials) jwtStillValid() bool {
|
||||||
|
return !zc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(zc.jwtToken.expiresInTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestJWTToken performs request to get jwt token.
|
||||||
|
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", zc.clientConfig.ClientID)
|
||||||
|
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||||
|
data.Set("grant_type", zc.clientConfig.GrantType)
|
||||||
|
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
|
||||||
|
|
||||||
|
payload := strings.NewReader(data.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
log.Debug("requesting new jwt token for zitadel idp manager")
|
||||||
|
|
||||||
|
resp, err := zc.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zc.appMetrics != nil {
|
||||||
|
zc.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds.
|
||||||
|
func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
body, err := io.ReadAll(rawBody)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = zc.helper.Unmarshal(body, &jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||||
|
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exp maps into exp from jwt token
|
||||||
|
var IssuedAt struct{ Exp int64 }
|
||||||
|
err = zc.helper.Unmarshal(data, &IssuedAt)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||||
|
|
||||||
|
return jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate retrieves access token to use the Zitadel Management API.
|
||||||
|
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
zc.mux.Lock()
|
||||||
|
defer zc.mux.Unlock()
|
||||||
|
|
||||||
|
if zc.appMetrics != nil {
|
||||||
|
zc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reuse the token without requesting a new one if it is not expired,
|
||||||
|
// and if expiry time is sufficient time available to make a request.
|
||||||
|
if zc.jwtStillValid() {
|
||||||
|
return zc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := zc.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return zc.jwtToken, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
jwtToken, err := zc.parseRequestJWTResponse(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return zc.jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
zc.jwtToken = jwtToken
|
||||||
|
|
||||||
|
return zc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in zitadel Idp and sends an invite.
|
||||||
|
func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||||
|
payload, err := buildZitadelCreateUserRequestPayload(email, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.post("users/human/_import", payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
}
|
||||||
|
err = zm.helper.Unmarshal(body, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
appMetadata := AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata to new user
|
||||||
|
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return zm.GetUserDataByID(result.UserID, appMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
|
searchByEmail := zitadelAttributes{
|
||||||
|
"queries": {
|
||||||
|
{
|
||||||
|
"emailQuery": map[string]any{
|
||||||
|
"emailAddress": email,
|
||||||
|
"method": "TEXT_QUERY_METHOD_EQUALS",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := zm.helper.Marshal(searchByEmail)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.post("users/_search", string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Result []zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, profile := range profiles.Result {
|
||||||
|
metadata, err := zm.getUserMetadata(profile.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.Metadata = metadata
|
||||||
|
|
||||||
|
users = append(users, profile.userData())
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from zitadel via ID.
|
||||||
|
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
|
body, err := zm.get("users/"+userID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile struct{ User zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata, err := zm.getUserMetadata(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.User.Metadata = metadata
|
||||||
|
|
||||||
|
return profile.User.userData(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given profile.
|
||||||
|
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
|
accounts, err := zm.GetAllAccounts()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts[accountID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// It returns a list of users indexed by accountID.
|
||||||
|
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
|
body, err := zm.post("users/_search", "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Result []zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
for _, profile := range profiles.Result {
|
||||||
|
// fetch user metadata
|
||||||
|
metadata, err := zm.getUserMetadata(profile.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.Metadata = metadata
|
||||||
|
|
||||||
|
userData := profile.userData()
|
||||||
|
accountID := userData.AppMetadata.WTAccountID
|
||||||
|
|
||||||
|
if accountID != "" {
|
||||||
|
if _, ok := indexedUsers[accountID]; !ok {
|
||||||
|
indexedUsers[accountID] = make([]*UserData, 0)
|
||||||
|
}
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
|
// Metadata values are base64 encoded.
|
||||||
|
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
|
if appMetadata.WTPendingInvite == nil {
|
||||||
|
appMetadata.WTPendingInvite = new(bool)
|
||||||
|
}
|
||||||
|
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
|
||||||
|
|
||||||
|
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
|
||||||
|
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
|
||||||
|
|
||||||
|
metadata := zitadelAttributes{
|
||||||
|
"metadata": {
|
||||||
|
{
|
||||||
|
"key": wtAccountID,
|
||||||
|
"value": wtAccountIDValue,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": wtPendingInvite,
|
||||||
|
"value": wtPendingInviteValue,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := zm.helper.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
|
||||||
|
_, err = zm.post(resource, string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserMetadata requests user metadata from zitadel via ID.
|
||||||
|
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||||
|
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||||
|
body, err := zm.post(resource, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var metadata struct{ Result []zitadelMetadata }
|
||||||
|
err = zm.helper.Unmarshal(body, &metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// post perform Post requests.
|
||||||
|
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||||
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||||
|
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := zm.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// get perform Get requests.
|
||||||
|
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s?%s", zm.managementEndpoint, resource, q.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := zm.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// value returns string represented by the base64 string value.
|
||||||
|
func (zm zitadelMetadata) value() string {
|
||||||
|
value, err := base64.StdEncoding.DecodeString(zm.Value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// userData construct user data from zitadel profile.
|
||||||
|
func (zp zitadelProfile) userData() *UserData {
|
||||||
|
var (
|
||||||
|
email string
|
||||||
|
name string
|
||||||
|
wtAccountIDValue string
|
||||||
|
wtPendingInviteValue bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, metadata := range zp.Metadata {
|
||||||
|
if metadata.Key == wtAccountID {
|
||||||
|
wtAccountIDValue = metadata.value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if metadata.Key == wtPendingInvite {
|
||||||
|
value, err := strconv.ParseBool(metadata.value())
|
||||||
|
if err == nil {
|
||||||
|
wtPendingInviteValue = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain the email for the human account and the login name,
|
||||||
|
// for the machine account.
|
||||||
|
if zp.Human != nil {
|
||||||
|
email = zp.Human.Email.Email
|
||||||
|
name = zp.Human.Profile.DisplayName
|
||||||
|
} else {
|
||||||
|
if len(zp.LoginNames) > 0 {
|
||||||
|
email = zp.LoginNames[0]
|
||||||
|
name = zp.LoginNames[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
Email: email,
|
||||||
|
Name: name,
|
||||||
|
ID: zp.ID,
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: wtAccountIDValue,
|
||||||
|
WTPendingInvite: &wtPendingInviteValue,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
|
||||||
|
var firstName, lastName string
|
||||||
|
|
||||||
|
words := strings.Fields(name)
|
||||||
|
if n := len(words); n > 0 {
|
||||||
|
firstName = strings.Join(words[:n-1], " ")
|
||||||
|
lastName = words[n-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &zitadelUser{
|
||||||
|
UserName: name,
|
||||||
|
Profile: zitadelUserInfo{
|
||||||
|
FirstName: strings.TrimSpace(firstName),
|
||||||
|
LastName: strings.TrimSpace(lastName),
|
||||||
|
DisplayName: name,
|
||||||
|
},
|
||||||
|
Email: zitadelEmail{
|
||||||
|
Email: email,
|
||||||
|
IsEmailVerified: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(str), nil
|
||||||
|
}
|
||||||
486
management/server/idp/zitadel_test.go
Normal file
486
management/server/idp/zitadel_test.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewZitadelManager(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
inputConfig ZitadelClientConfig
|
||||||
|
assertErrFunc require.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTestConfig := ZitadelClientConfig{
|
||||||
|
ClientID: "client_id",
|
||||||
|
ClientSecret: "client_secret",
|
||||||
|
GrantType: "client_credentials",
|
||||||
|
TokenEndpoint: "http://localhost/oauth/v2/token",
|
||||||
|
ManagementEndpoint: "http://localhost/management/v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase1 := test{
|
||||||
|
name: "Good Configuration",
|
||||||
|
inputConfig: defaultTestConfig,
|
||||||
|
assertErrFunc: require.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase2Config := defaultTestConfig
|
||||||
|
testCase2Config.ClientID = ""
|
||||||
|
|
||||||
|
testCase2 := test{
|
||||||
|
name: "Missing ClientID Configuration",
|
||||||
|
inputConfig: testCase2Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase3Config := defaultTestConfig
|
||||||
|
testCase3Config.ClientSecret = ""
|
||||||
|
|
||||||
|
testCase3 := test{
|
||||||
|
name: "Missing ClientSecret Configuration",
|
||||||
|
inputConfig: testCase3Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
_, err := NewZitadelManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockZitadelCredentials struct {
|
||||||
|
jwtToken JWTToken
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
return mc.jwtToken, mc.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelRequestJWTToken(t *testing.T) {
|
||||||
|
|
||||||
|
type requestJWTTokenTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||||
|
name: "Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||||
|
name: "Request Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputRespBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputRespBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := creds.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err, "unable to read the response body")
|
||||||
|
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||||
|
assert.NoError(t, err, "unable to parse the json input")
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelParseRequestJWTResponse(t *testing.T) {
|
||||||
|
type parseRequestJWTResponseTest struct {
|
||||||
|
name string
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedToken string
|
||||||
|
expectedExpiresIn int
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := 100
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Good JWT Body",
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
expectedExpiresIn: exp,
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "no error was expected",
|
||||||
|
}
|
||||||
|
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Bad json JWT Body",
|
||||||
|
inputRespBody: "",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: "",
|
||||||
|
expectedExpiresIn: 0,
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "json error was expected",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelJwtStillValid(t *testing.T) {
|
||||||
|
type jwtStillValidTest struct {
|
||||||
|
name string
|
||||||
|
inputTime time.Time
|
||||||
|
expectedResult bool
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||||
|
name: "JWT still valid",
|
||||||
|
inputTime: time.Now().Add(10 * time.Second),
|
||||||
|
expectedResult: true,
|
||||||
|
message: "should be true",
|
||||||
|
}
|
||||||
|
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||||
|
name: "JWT is invalid",
|
||||||
|
inputTime: time.Now(),
|
||||||
|
expectedResult: false,
|
||||||
|
message: "should be false",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelAuthenticate(t *testing.T) {
|
||||||
|
type authenticateTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputResBody string
|
||||||
|
inputExpireToken time.Time
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedCode int
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
authenticateTestCase1 := authenticateTest{
|
||||||
|
name: "Get Cached token",
|
||||||
|
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: nil,
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase2 := authenticateTest{
|
||||||
|
name: "Get Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase3 := authenticateTest{
|
||||||
|
name: "Get Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputResBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputResBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||||
|
|
||||||
|
_, err := creds.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
|
||||||
|
type updateUserAppMetadataTest struct {
|
||||||
|
name string
|
||||||
|
inputReqBody string
|
||||||
|
expectedReqBody string
|
||||||
|
appMetadata AppMetadata
|
||||||
|
statusCode int
|
||||||
|
helper ManagerHelper
|
||||||
|
managerCreds ManagerCredentials
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Authentication",
|
||||||
|
expectedReqBody: "",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
err: fmt.Errorf("error"),
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Response Parsing",
|
||||||
|
statusCode: 400,
|
||||||
|
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||||
|
name: "Good request",
|
||||||
|
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 200,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||||
|
name: "Update Pending Invite",
|
||||||
|
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
|
||||||
|
appMetadata: AppMetadata{
|
||||||
|
WTAccountID: "ok",
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
},
|
||||||
|
statusCode: 200,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||||
|
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
reqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputReqBody,
|
||||||
|
code: testCase.statusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
httpClient: &reqClient,
|
||||||
|
credentials: testCase.managerCreds,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelProfile(t *testing.T) {
|
||||||
|
type azureProfileTest struct {
|
||||||
|
name string
|
||||||
|
invite bool
|
||||||
|
inputProfile zitadelProfile
|
||||||
|
expectedUserData UserData
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase1 := azureProfileTest{
|
||||||
|
name: "User Request",
|
||||||
|
invite: false,
|
||||||
|
inputProfile: zitadelProfile{
|
||||||
|
ID: "test1",
|
||||||
|
State: "USER_STATE_ACTIVE",
|
||||||
|
UserName: "test1@mail.com",
|
||||||
|
PreferredLoginName: "test1@mail.com",
|
||||||
|
LoginNames: []string{
|
||||||
|
"test1@mail.com",
|
||||||
|
},
|
||||||
|
Human: &zitadelUser{
|
||||||
|
Profile: zitadelUserInfo{
|
||||||
|
FirstName: "ZITADEL",
|
||||||
|
LastName: "Admin",
|
||||||
|
DisplayName: "ZITADEL Admin",
|
||||||
|
},
|
||||||
|
Email: zitadelEmail{
|
||||||
|
Email: "test1@mail.com",
|
||||||
|
IsEmailVerified: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: []zitadelMetadata{
|
||||||
|
{
|
||||||
|
Key: "wt_account_id",
|
||||||
|
Value: "MQ==",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: "wt_pending_invite",
|
||||||
|
Value: "ZmFsc2U=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test1",
|
||||||
|
Name: "ZITADEL Admin",
|
||||||
|
Email: "test1@mail.com",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase2 := azureProfileTest{
|
||||||
|
name: "Service User Request",
|
||||||
|
invite: true,
|
||||||
|
inputProfile: zitadelProfile{
|
||||||
|
ID: "test2",
|
||||||
|
State: "USER_STATE_ACTIVE",
|
||||||
|
UserName: "machine",
|
||||||
|
PreferredLoginName: "machine",
|
||||||
|
LoginNames: []string{
|
||||||
|
"machine",
|
||||||
|
},
|
||||||
|
Human: nil,
|
||||||
|
Metadata: []zitadelMetadata{
|
||||||
|
{
|
||||||
|
Key: "wt_account_id",
|
||||||
|
Value: "MQ==",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: "wt_pending_invite",
|
||||||
|
Value: "dHJ1ZQ==",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test2",
|
||||||
|
Name: "machine",
|
||||||
|
Email: "machine",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||||
|
userData := testCase.inputProfile.userData()
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -45,7 +49,8 @@ type Options struct {
|
|||||||
|
|
||||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||||
type Jwks struct {
|
type Jwks struct {
|
||||||
Keys []JSONWebKey `json:"keys"`
|
Keys []JSONWebKey `json:"keys"`
|
||||||
|
expiresInTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// JSONWebKey is a representation of a Jason Web Key
|
// JSONWebKey is a representation of a Jason Web Key
|
||||||
@@ -64,12 +69,13 @@ type JWTValidator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTValidator constructor
|
// NewJWTValidator constructor
|
||||||
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) {
|
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
||||||
keys, err := getPemKeys(keysLocation)
|
keys, err := getPemKeys(keysLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lock sync.Mutex
|
||||||
options := Options{
|
options := Options{
|
||||||
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
|
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
|
||||||
// Verify 'aud' claim
|
// Verify 'aud' claim
|
||||||
@@ -89,6 +95,23 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string)
|
|||||||
return token, errors.New("invalid issuer")
|
return token, errors.New("invalid issuer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If keys are rotated, verify the keys prior to token validation
|
||||||
|
if idpSignkeyRefreshEnabled {
|
||||||
|
// If the keys are invalid, retrieve new ones
|
||||||
|
if !keys.stillValid() {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
refreshedKeys, err := getPemKeys(keysLocation)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||||
|
refreshedKeys = keys
|
||||||
|
}
|
||||||
|
|
||||||
|
keys = refreshedKeys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cert, err := getPemCert(token, keys)
|
cert, err := getPemCert(token, keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -154,6 +177,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
|||||||
return parsedToken, nil
|
return parsedToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
||||||
|
func (jwks *Jwks) stillValid() bool {
|
||||||
|
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
||||||
|
}
|
||||||
|
|
||||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||||
resp, err := http.Get(keysLocation)
|
resp, err := http.Get(keysLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -167,6 +195,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
|||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||||
|
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
||||||
|
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,3 +280,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
|
|||||||
|
|
||||||
return int(exponent), nil
|
return int(exponent), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||||
|
func getMaxAgeFromCacheHeader(cacheControl string) int {
|
||||||
|
// Split into individual directives
|
||||||
|
directives := strings.Split(cacheControl, ",")
|
||||||
|
|
||||||
|
for _, directive := range directives {
|
||||||
|
directive = strings.TrimSpace(directive)
|
||||||
|
if strings.HasPrefix(directive, "max-age=") {
|
||||||
|
// Extract the max-age value
|
||||||
|
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
|
||||||
|
maxAge, err := strconv.Atoi(maxAgeStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("error parsing max-age: %v", err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxAge
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,12 +15,11 @@ import (
|
|||||||
|
|
||||||
type MockAccountManager struct {
|
type MockAccountManager struct {
|
||||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||||
GetAccountByUserIDFunc func(userID string) (*server.Account, error)
|
|
||||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
IsUserAdminFunc func(userID string) (bool, error)
|
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
AccountExistsFunc func(accountId string) (*bool, error)
|
AccountExistsFunc func(accountId string) (*bool, error)
|
||||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||||
@@ -61,11 +60,11 @@ type MockAccountManager struct {
|
|||||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||||
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
@@ -113,14 +112,6 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
|
|
||||||
if am.GetAccountByUserIDFunc != nil {
|
|
||||||
return am.GetAccountByUserIDFunc(userID)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreateSetupKey(
|
func (am *MockAccountManager) CreateSetupKey(
|
||||||
accountID string,
|
accountID string,
|
||||||
@@ -199,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||||
if am.CreatePATFunc != nil {
|
if am.CreatePATFunc != nil {
|
||||||
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
|
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if am.DeletePATFunc != nil {
|
if am.DeletePATFunc != nil {
|
||||||
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
|
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||||
if am.GetPATFunc != nil {
|
if am.GetPATFunc != nil {
|
||||||
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
|
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||||
if am.GetAllPATsFunc != nil {
|
if am.GetAllPATsFunc != nil {
|
||||||
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
|
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||||
}
|
}
|
||||||
@@ -394,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
|||||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||||
func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
|
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||||
if am.IsUserAdminFunc != nil {
|
if am.GetUserFunc != nil {
|
||||||
return am.IsUserAdminFunc(userID)
|
return am.GetUserFunc(claims)
|
||||||
}
|
}
|
||||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||||
@@ -502,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
if am.DeleteUserFunc != nil {
|
if am.DeleteUserFunc != nil {
|
||||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
|
|||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if peerLoginExpired(peer, account) {
|
if peerLoginExpired(peer, account) {
|
||||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
@@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
|||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
updateRemotePeers := false
|
updateRemotePeers := false
|
||||||
if peerLoginExpired(peer, account) {
|
if peerLoginExpired(peer, account) {
|
||||||
err = checkAuth(login.UserID, peer)
|
err = checkAuth(login.UserID, peer)
|
||||||
@@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
|||||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
user, err := account.FindUser(peer.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||||
|
}
|
||||||
|
if user.IsBlocked() {
|
||||||
|
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func checkAuth(loginUserID string, peer *Peer) error {
|
func checkAuth(loginUserID string, peer *Peer) error {
|
||||||
if loginUserID == "" {
|
if loginUserID == "" {
|
||||||
// absence of a user ID indicates that JWT wasn't provided.
|
// absence of a user ID indicates that JWT wasn't provided.
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,15 +51,22 @@ type User struct {
|
|||||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||||
AutoGroups []string
|
AutoGroups []string
|
||||||
PATs map[string]*PersonalAccessToken
|
PATs map[string]*PersonalAccessToken
|
||||||
|
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||||
|
Blocked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAdmin returns true if user is an admin, false otherwise
|
// IsBlocked returns true if the user is blocked, false otherwise
|
||||||
|
func (u *User) IsBlocked() bool {
|
||||||
|
return u.Blocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAdmin returns true if the user is an admin, false otherwise
|
||||||
func (u *User) IsAdmin() bool {
|
func (u *User) IsAdmin() bool {
|
||||||
return u.Role == UserRoleAdmin
|
return u.Role == UserRoleAdmin
|
||||||
}
|
}
|
||||||
|
|
||||||
// toUserInfo converts a User object to a UserInfo object.
|
// ToUserInfo converts a User object to a UserInfo object.
|
||||||
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||||
autoGroups := u.AutoGroups
|
autoGroups := u.AutoGroups
|
||||||
if autoGroups == nil {
|
if autoGroups == nil {
|
||||||
autoGroups = []string{}
|
autoGroups = []string{}
|
||||||
@@ -73,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
AutoGroups: u.AutoGroups,
|
AutoGroups: u.AutoGroups,
|
||||||
Status: string(UserStatusActive),
|
Status: string(UserStatusActive),
|
||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
if userData.ID != u.Id {
|
if userData.ID != u.Id {
|
||||||
@@ -92,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
AutoGroups: autoGroups,
|
AutoGroups: autoGroups,
|
||||||
Status: string(userStatus),
|
Status: string(userStatus),
|
||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,6 +122,7 @@ func (u *User) Copy() *User {
|
|||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
ServiceUserName: u.ServiceUserName,
|
ServiceUserName: u.ServiceUserName,
|
||||||
PATs: pats,
|
PATs: pats,
|
||||||
|
Blocked: u.Blocked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +148,7 @@ func NewAdminUser(id string) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -146,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
|||||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if executingUser == nil {
|
if executingUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
@@ -165,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||||
|
|
||||||
return &UserInfo{
|
return &UserInfo{
|
||||||
ID: newUser.Id,
|
ID: newUser.Id,
|
||||||
@@ -211,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||||
@@ -220,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users) > 0 {
|
if len(users) > 0 {
|
||||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||||
@@ -248,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
|
|
||||||
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||||
|
|
||||||
return newUser.toUserInfo(idpUser)
|
return newUser.ToUserInfo(idpUser)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser looks up a user by provided authorization claims.
|
||||||
|
// It will also create an account if didn't exist for this user before.
|
||||||
|
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
|
account, _, err := am.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := account.Users[claims.UserId]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteUser deletes a user from the given account.
|
// DeleteUser deletes a user from the given account.
|
||||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -267,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if executingUser == nil {
|
if executingUser == nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
@@ -280,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||||
|
|
||||||
delete(account.Users, targetUserID)
|
delete(account.Users, targetUserID)
|
||||||
|
|
||||||
@@ -293,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreatePAT creates a new PAT for the given user
|
// CreatePAT creates a new PAT for the given user
|
||||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -315,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
|||||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||||
|
|
||||||
return pat, nil
|
return pat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePAT deletes a specific PAT from a user
|
// DeletePAT deletes a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -357,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||||
|
|
||||||
delete(targetUser.PATs, tokenID)
|
delete(targetUser.PATs, tokenID)
|
||||||
|
|
||||||
@@ -393,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPAT returns a specific PAT from a user
|
// GetPAT returns a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -407,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
|||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -439,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
|||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
|||||||
return pats, nil
|
return pats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
||||||
// Only User.AutoGroups field is allowed to be updated for now.
|
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||||
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
|
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -471,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
initiatorUser, err := account.FindUser(initiatorUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldUser := account.Users[update.Id]
|
||||||
|
if oldUser == nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||||
|
}
|
||||||
|
|
||||||
|
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||||
|
}
|
||||||
|
|
||||||
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
|
newUser := oldUser.Copy()
|
||||||
|
newUser.Role = update.Role
|
||||||
|
newUser.Blocked = update.Blocked
|
||||||
|
|
||||||
for _, newGroupID := range update.AutoGroups {
|
for _, newGroupID := range update.AutoGroups {
|
||||||
if _, ok := account.Groups[newGroupID]; !ok {
|
if _, ok := account.Groups[newGroupID]; !ok {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||||
newGroupID, update.Id)
|
newGroupID, update.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUser := account.Users[update.Id]
|
|
||||||
if oldUser == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "update not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// only auto groups, revoked status, and name can be updated for now
|
|
||||||
newUser := oldUser.Copy()
|
|
||||||
newUser.AutoGroups = update.AutoGroups
|
newUser.AutoGroups = update.AutoGroups
|
||||||
newUser.Role = update.Role
|
|
||||||
|
|
||||||
account.Users[newUser.Id] = newUser
|
account.Users[newUser.Id] = newUser
|
||||||
|
|
||||||
|
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||||
|
// expire peers that belong to the user who's getting blocked
|
||||||
|
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range blockedPeers {
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
peer.MarkLoginExpired(true)
|
||||||
|
account.UpdatePeer(peer)
|
||||||
|
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||||
|
err = am.updateAccountPeers(account)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
||||||
|
return nil, err
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(account); err != nil {
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// store activity logs
|
||||||
if oldUser.Role != newUser.Role {
|
if oldUser.Role != newUser.Role {
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||||
}
|
}
|
||||||
|
|
||||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
if update.AutoGroups != nil {
|
||||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||||
for _, g := range removedGroups {
|
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||||
group := account.GetGroup(g)
|
for _, g := range removedGroups {
|
||||||
if group != nil {
|
group := account.GetGroup(g)
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
if group != nil {
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||||
} else {
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
} else {
|
||||||
|
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
for _, g := range addedGroups {
|
||||||
|
group := account.GetGroup(g)
|
||||||
for _, g := range addedGroups {
|
if group != nil {
|
||||||
group := account.GetGroup(g)
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||||
if group != nil {
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
}
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
|
||||||
} else {
|
|
||||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||||
@@ -531,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
|||||||
if userData == nil {
|
if userData == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||||
}
|
}
|
||||||
return newUser.toUserInfo(userData)
|
return newUser.ToUserInfo(userData)
|
||||||
}
|
}
|
||||||
return newUser.toUserInfo(nil)
|
return newUser.ToUserInfo(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||||
@@ -573,26 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserID returns an existing account for a given user id
|
|
||||||
func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
|
|
||||||
return am.Store.GetAccountByUser(userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
|
||||||
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
|
|
||||||
account, err := am.GetAccountByUserID(userID)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get account: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := account.Users[userID]
|
|
||||||
if !ok {
|
|
||||||
return false, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return user.Role == UserRoleAdmin, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||||
// based on provided user role.
|
// based on provided user role.
|
||||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
||||||
@@ -629,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
// if user is not an admin then show only current user and do not show other users
|
// if user is not an admin then show only current user and do not show other users
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
info, err := accountUser.toUserInfo(nil)
|
info, err := accountUser.ToUserInfo(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -646,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
|
|
||||||
var info *UserInfo
|
var info *UserInfo
|
||||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||||
info, err = localUser.toUserInfo(queriedUser)
|
info, err = localUser.ToUserInfo(queriedUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ import (
|
|||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -264,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
|
|||||||
LastUsed: time.Now(),
|
LastUsed: time.Now(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Blocked: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := validateStruct(user)
|
err := validateStruct(user)
|
||||||
@@ -287,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
|
|||||||
field := structVal.Field(i)
|
field := structVal.Field(i)
|
||||||
fieldName := structType.Field(i).Name
|
fieldName := structType.Field(i).Name
|
||||||
|
|
||||||
isSet := field.IsValid() && !field.IsZero()
|
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
|
||||||
|
|
||||||
if !isSet {
|
if !isSet {
|
||||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||||
@@ -439,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
|||||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
|
|
||||||
@@ -453,38 +456,27 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
|||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := am.IsUserAdmin(mockUserID)
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
UserId: mockUserID,
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := am.GetUser(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when checking user role: %s", err)
|
t.Fatalf("Error when checking user role: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.Equal(t, mockUserID, user.Id)
|
||||||
|
assert.True(t, user.IsAdmin())
|
||||||
|
assert.False(t, user.IsBlocked())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
func TestUser_IsAdmin(t *testing.T) {
|
||||||
store := newStore(t)
|
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
|
||||||
account.Users[mockUserID] = &User{
|
|
||||||
Id: mockUserID,
|
|
||||||
Role: "user",
|
|
||||||
}
|
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
user := NewAdminUser(mockUserID)
|
||||||
if err != nil {
|
assert.True(t, user.IsAdmin())
|
||||||
t.Fatalf("Error when saving account: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
am := DefaultAccountManager{
|
user = NewRegularUser(mockUserID)
|
||||||
Store: store,
|
assert.False(t, user.IsAdmin())
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := am.IsUserAdmin(mockUserID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error when checking user role: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||||
@@ -541,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(users))
|
assert.Equal(t, 1, len(users))
|
||||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
regularUserID := "regularUser"
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
adminInitiator bool
|
||||||
|
update *User
|
||||||
|
expectedErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_Admin_Role",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleUser,
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "Should_Fail_When_Admin_Blocks_Themselves",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_Non_Existing_User",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: false,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Update_User",
|
||||||
|
expectedErr: false,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: regularUserID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
|
||||||
|
// create an account and an admin user
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a regular user
|
||||||
|
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||||
|
err = manager.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
initiatorID := userID
|
||||||
|
if !tc.adminInitiator {
|
||||||
|
initiatorID = regularUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
|
||||||
|
if tc.expectedErr {
|
||||||
|
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "expecting SaveUser not to throw an error")
|
||||||
|
assert.NotNil(t, updated)
|
||||||
|
|
||||||
|
assert.Equal(t, string(tc.update.Role), updated.Role)
|
||||||
|
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ download_release_binary() {
|
|||||||
echo "Installing $1 from $DOWNLOAD_URL"
|
echo "Installing $1 from $DOWNLOAD_URL"
|
||||||
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
||||||
|
|
||||||
|
|
||||||
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
|
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
|
||||||
INSTALL_DIR="/Applications/NetBird UI.app"
|
INSTALL_DIR="/Applications/NetBird UI.app"
|
||||||
|
|
||||||
@@ -43,8 +44,9 @@ download_release_binary() {
|
|||||||
unzip -q -o "$BINARY_NAME"
|
unzip -q -o "$BINARY_NAME"
|
||||||
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
|
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
|
||||||
else
|
else
|
||||||
|
sudo mkdir -p "$INSTALL_DIR"
|
||||||
tar -xzvf "$BINARY_NAME"
|
tar -xzvf "$BINARY_NAME"
|
||||||
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR"
|
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,4 +283,4 @@ install_netbird() {
|
|||||||
echo "sudo netbird up"
|
echo "sudo netbird up"
|
||||||
}
|
}
|
||||||
|
|
||||||
install_netbird
|
install_netbird
|
||||||
|
|||||||
11
sharedsock/filter.go
Normal file
11
sharedsock/filter.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package sharedsock
|
||||||
|
|
||||||
|
import "golang.org/x/net/bpf"
|
||||||
|
|
||||||
|
const magicCookie uint32 = 0x2112A442
|
||||||
|
|
||||||
|
// BPFFilter is a generic filter that provides ipv4 and ipv6 BPF instructions
|
||||||
|
type BPFFilter interface {
|
||||||
|
// GetInstructions returns raw BPF instructions for ipv4 and ipv6
|
||||||
|
GetInstructions(port uint32) (ipv4 []bpf.RawInstruction, ipv6 []bpf.RawInstruction, err error)
|
||||||
|
}
|
||||||
47
sharedsock/filter_linux.go
Normal file
47
sharedsock/filter_linux.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package sharedsock
|
||||||
|
|
||||||
|
import "golang.org/x/net/bpf"
|
||||||
|
|
||||||
|
// IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port.
|
||||||
|
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
|
||||||
|
type IncomingSTUNFilter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter
|
||||||
|
func NewIncomingSTUNFilter() BPFFilter {
|
||||||
|
return &IncomingSTUNFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
|
||||||
|
func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
|
||||||
|
raw4, err = rawInstructions(22, 32, dstPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
raw6, err = rawInstructions(2, 12, dstPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return raw4, raw6, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) {
|
||||||
|
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers
|
||||||
|
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
|
||||||
|
instructions := []bpf.Instruction{
|
||||||
|
// Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
|
||||||
|
bpf.LoadAbsolute{Off: dstPortOff, Size: 2},
|
||||||
|
// Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions.
|
||||||
|
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3},
|
||||||
|
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
|
||||||
|
bpf.LoadAbsolute{Off: cookieOff, Size: 4},
|
||||||
|
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
|
||||||
|
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
|
||||||
|
// If both the dstPort and the magic cookie match, return a positive value (0xffffffff)
|
||||||
|
bpf.RetConstant{Val: 0xffffffff},
|
||||||
|
// If either the dstPort or the magic cookie doesn't match, return 0
|
||||||
|
bpf.RetConstant{Val: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpf.Assemble(instructions)
|
||||||
|
}
|
||||||
8
sharedsock/filter_nolinux.go
Normal file
8
sharedsock/filter_nolinux.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package sharedsock
|
||||||
|
|
||||||
|
// NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
|
||||||
|
func NewIncomingSTUNFilter() BPFFilter {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
328
sharedsock/sock_linux.go
Normal file
328
sharedsock/sock_linux.go
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
// Inspired by
|
||||||
|
// Jason Donenfeld (https://git.zx2c4.com/wireguard-tools/tree/contrib/nat-hole-punching/nat-punch-client.c#n96)
|
||||||
|
// and @stv0g in https://github.com/stv0g/cunicu/tree/ebpf-poc/ebpf_poc
|
||||||
|
|
||||||
|
package sharedsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/gopacket/routing"
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
"github.com/mdlayher/socket"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrSharedSockStopped indicates that shared socket has been stopped
|
||||||
|
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
|
||||||
|
|
||||||
|
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
|
||||||
|
// by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
|
||||||
|
// It is meant to be used when sharing a port with some other process.
|
||||||
|
type SharedSocket struct {
|
||||||
|
ctx context.Context
|
||||||
|
conn4 *socket.Conn
|
||||||
|
conn6 *socket.Conn
|
||||||
|
port int
|
||||||
|
routerMux sync.RWMutex
|
||||||
|
router routing.Router
|
||||||
|
packetDemux chan rcvdPacket
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type rcvdPacket struct {
|
||||||
|
n int
|
||||||
|
addr unix.Sockaddr
|
||||||
|
buf []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type receiver func(ctx context.Context, p []byte, flags int) (int, unix.Sockaddr, error)
|
||||||
|
|
||||||
|
var writeSerializerOptions = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines
|
||||||
|
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
|
||||||
|
var err error
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
rawSock := &SharedSocket{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
port: port,
|
||||||
|
packetDemux: make(chan rcvdPacket),
|
||||||
|
}
|
||||||
|
|
||||||
|
rawSock.router, err = netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create router: %rawSock", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("socket.Socket for ipv4 failed with: %rawSock", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("socket.Socket for ipv6 failed with: %rawSock", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
|
||||||
|
if err != nil {
|
||||||
|
_ = rawSock.Close()
|
||||||
|
return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rawSock.conn4.SetBPF(ipv4Instructions)
|
||||||
|
if err != nil {
|
||||||
|
_ = rawSock.Close()
|
||||||
|
return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err)
|
||||||
|
}
|
||||||
|
if rawSock.conn6 != nil {
|
||||||
|
err = rawSock.conn6.SetBPF(ipv6Instructions)
|
||||||
|
if err != nil {
|
||||||
|
_ = rawSock.Close()
|
||||||
|
return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go rawSock.read(rawSock.conn4.Recvfrom)
|
||||||
|
if rawSock.conn6 != nil {
|
||||||
|
go rawSock.read(rawSock.conn6.Recvfrom)
|
||||||
|
}
|
||||||
|
|
||||||
|
go rawSock.updateRouter()
|
||||||
|
|
||||||
|
return rawSock, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRouter updates the listener routing table client
|
||||||
|
// this is needed to avoid outdated information across different client networks
|
||||||
|
func (s *SharedSocket) updateRouter() {
|
||||||
|
ticker := time.NewTicker(15 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
router, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create and update packet router for stunListener: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.routerMux.Lock()
|
||||||
|
s.router = router
|
||||||
|
s.routerMux.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns an IPv4 address using the supplied port
|
||||||
|
func (s *SharedSocket) LocalAddr() net.Addr {
|
||||||
|
// todo check impact on ipv6 discovery
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
Port: s.port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets both the read and write deadlines associated with the ipv4 and ipv6 Conn sockets
|
||||||
|
func (s *SharedSocket) SetDeadline(t time.Time) error {
|
||||||
|
err := s.conn4.SetDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn4.SetDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
if s.conn6 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.conn6.SetDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn6.SetDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline associated with the ipv4 and ipv6 Conn sockets
|
||||||
|
func (s *SharedSocket) SetReadDeadline(t time.Time) error {
|
||||||
|
err := s.conn4.SetReadDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
if s.conn6 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.conn6.SetReadDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the write deadline associated with the ipv4 and ipv6 Conn sockets
|
||||||
|
func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
|
||||||
|
err := s.conn4.SetWriteDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
if s.conn6 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.conn6.SetWriteDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying ipv4 and ipv6 conn sockets
|
||||||
|
func (s *SharedSocket) Close() error {
|
||||||
|
s.cancel()
|
||||||
|
errGrp := errgroup.Group{}
|
||||||
|
if s.conn4 != nil {
|
||||||
|
errGrp.Go(s.conn4.Close)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.conn6 != nil {
|
||||||
|
errGrp.Go(s.conn6.Close)
|
||||||
|
}
|
||||||
|
return errGrp.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// read start a read loop for a specific receiver and sends the packet to the packetDemux channel
|
||||||
|
func (s *SharedSocket) read(receiver receiver) {
|
||||||
|
for {
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
n, addr, err := receiver(s.ctx, buf, 0)
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
case s.packetDemux <- rcvdPacket{n, addr, buf[:n], err}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrom reads packets received in the packetDemux channel
|
||||||
|
func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
var pkt rcvdPacket
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return -1, nil, ErrSharedSockStopped
|
||||||
|
case pkt = <-s.packetDemux:
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkt.err != nil {
|
||||||
|
return -1, nil, pkt.err
|
||||||
|
}
|
||||||
|
var ip4layer layers.IPv4
|
||||||
|
var udp layers.UDP
|
||||||
|
var payload gopacket.Payload
|
||||||
|
var parser *gopacket.DecodingLayerParser
|
||||||
|
var ip net.IP
|
||||||
|
|
||||||
|
if sa, isIPv4 := pkt.addr.(*unix.SockaddrInet4); isIPv4 {
|
||||||
|
ip = sa.Addr[:]
|
||||||
|
parser = gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &ip4layer, &udp, &payload)
|
||||||
|
} else if sa, isIPv6 := pkt.addr.(*unix.SockaddrInet6); isIPv6 {
|
||||||
|
ip = sa.Addr[:]
|
||||||
|
parser = gopacket.NewDecodingLayerParser(layers.LayerTypeUDP, &udp, &payload)
|
||||||
|
} else {
|
||||||
|
return -1, nil, fmt.Errorf("received invalid address family")
|
||||||
|
}
|
||||||
|
|
||||||
|
decodedLayers := make([]gopacket.LayerType, 0, 3)
|
||||||
|
|
||||||
|
err = parser.DecodeLayers(pkt.buf[:], &decodedLayers)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr := &net.UDPAddr{
|
||||||
|
IP: ip,
|
||||||
|
Port: int(udp.SrcPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(b, payload)
|
||||||
|
return int(udp.Length), remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo builds a UDP packet and writes it using the specific IP version writter
|
||||||
|
func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||||
|
rUDPAddr, ok := rAddr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return -1, fmt.Errorf("invalid address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
payload := gopacket.Payload(buf)
|
||||||
|
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(s.port),
|
||||||
|
DstPort: layers.UDPPort(rUDPAddr.Port),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.routerMux.RLock()
|
||||||
|
defer s.routerMux.RUnlock()
|
||||||
|
|
||||||
|
_, _, src, err := s.router.Route(rUDPAddr.IP)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("got an error while checking route, err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
|
||||||
|
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
|
||||||
|
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil {
|
||||||
|
return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bufser := buffer.Bytes()
|
||||||
|
|
||||||
|
return 0, conn.Sendto(context.TODO(), bufser, 0, rSockAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWriterObjects returns the specific IP version objects that are used to build a packet and send it using the raw socket
|
||||||
|
func (s *SharedSocket) getWriterObjects(src, dest net.IP) (sa unix.Sockaddr, conn *socket.Conn, layer gopacket.NetworkLayer) {
|
||||||
|
if dest.To4() == nil {
|
||||||
|
sa = &unix.SockaddrInet6{}
|
||||||
|
copy(sa.(*unix.SockaddrInet6).Addr[:], dest.To16())
|
||||||
|
conn = s.conn6
|
||||||
|
|
||||||
|
layer = &layers.IPv6{
|
||||||
|
SrcIP: src,
|
||||||
|
DstIP: dest,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sa = &unix.SockaddrInet4{}
|
||||||
|
copy(sa.(*unix.SockaddrInet4).Addr[:], dest.To4())
|
||||||
|
conn = s.conn4
|
||||||
|
layer = &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
SrcIP: src,
|
||||||
|
DstIP: dest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sa, conn, layer
|
||||||
|
}
|
||||||
162
sharedsock/sock_linux_test.go
Normal file
162
sharedsock/sock_linux_test.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package sharedsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/stun"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldReadSTUNOnReadFrom(t *testing.T) {
|
||||||
|
|
||||||
|
// create raw socket on a port
|
||||||
|
testingPort := 51821
|
||||||
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
|
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||||
|
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
// when reading from the raw socket
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
rcvMSG := &stun.Message{
|
||||||
|
Raw: buf,
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, _, err := rawSock.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while reading packet %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rcvMSG.Decode()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
// and sending STUN packet to the shared port, the packet has to be handled
|
||||||
|
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 12345, IP: net.ParseIP("127.0.0.1")})
|
||||||
|
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
|
||||||
|
defer udpListener.Close()
|
||||||
|
stunMSG, err := stun.Build(stun.NewType(stun.MethodBinding, stun.ClassRequest), stun.TransactionID,
|
||||||
|
stun.Fingerprint,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "unable to build stun msg, error: %s", err)
|
||||||
|
_, err = udpListener.WriteTo(stunMSG.Raw, net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort))))
|
||||||
|
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
|
||||||
|
|
||||||
|
// the packet has to be handled and be a STUN packet
|
||||||
|
wg.Wait()
|
||||||
|
require.EqualValues(t, stunMSG.TransactionID, rcvMSG.TransactionID, "transaction id values did't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldNotReadNonSTUNPackets(t *testing.T) {
|
||||||
|
testingPort := 39439
|
||||||
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
|
defer rawSock.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
err = rawSock.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
||||||
|
|
||||||
|
errGrp := errgroup.Group{}
|
||||||
|
errGrp.Go(func() error {
|
||||||
|
_, _, err := rawSock.ReadFrom(buf)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
nonStun := []byte("netbird")
|
||||||
|
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
|
||||||
|
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
|
||||||
|
defer udpListener.Close()
|
||||||
|
remote := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort)))
|
||||||
|
_, err = udpListener.WriteTo(nonStun, remote)
|
||||||
|
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
|
||||||
|
|
||||||
|
err = errGrp.Wait()
|
||||||
|
require.Error(t, err, "should receive an error")
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
t.Errorf("error should be I/O timeout, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteTo(t *testing.T) {
|
||||||
|
udpListener, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
|
||||||
|
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
|
||||||
|
defer udpListener.Close()
|
||||||
|
|
||||||
|
testingPort := 39440
|
||||||
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
|
defer rawSock.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
err = udpListener.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||||
|
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
||||||
|
|
||||||
|
errGrp := errgroup.Group{}
|
||||||
|
var remoteAdr net.Addr
|
||||||
|
var rcvBytes int
|
||||||
|
errGrp.Go(func() error {
|
||||||
|
n, a, err := udpListener.ReadFrom(buf)
|
||||||
|
remoteAdr = a
|
||||||
|
rcvBytes = n
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
msg := []byte("netbird")
|
||||||
|
_, err = rawSock.WriteTo(msg, udpListener.LocalAddr())
|
||||||
|
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
|
||||||
|
|
||||||
|
err = errGrp.Wait()
|
||||||
|
require.NoError(t, err, "received an error while reading the packet, error: %s", err)
|
||||||
|
|
||||||
|
require.EqualValues(t, string(msg), string(buf[:rcvBytes]), "received message should match")
|
||||||
|
|
||||||
|
udpRcv, ok := remoteAdr.(*net.UDPAddr)
|
||||||
|
require.True(t, ok, "udp address conversion didn't work")
|
||||||
|
|
||||||
|
require.EqualValues(t, testingPort, udpRcv.Port, "received address port didn't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedSocket_Close(t *testing.T) {
|
||||||
|
rawSock, err := Listen(39440, NewIncomingSTUNFilter())
|
||||||
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
|
|
||||||
|
errGrp := errgroup.Group{}
|
||||||
|
|
||||||
|
errGrp.Go(func() error {
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_, _, err := rawSock.ReadFrom(buf)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
_ = rawSock.Close()
|
||||||
|
err = errGrp.Wait()
|
||||||
|
if err != ErrSharedSockStopped {
|
||||||
|
t.Errorf("invalid error response: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
14
sharedsock/sock_nolinux.go
Normal file
14
sharedsock/sock_nolinux.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package sharedsock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Listen is not supported on other platforms
|
||||||
|
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
|
||||||
|
return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS))
|
||||||
|
}
|
||||||
@@ -350,7 +350,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
|||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||||
|
|
||||||
decryptedMessage, err := c.decryptMessage(msg)
|
decryptedMessage, err := c.decryptMessage(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user