Compare commits

...

25 Commits

Author SHA1 Message Date
Zoltan Papp
7b43d7e8ef Delete proxy package
The proxy package contains duplicated wg
configuration logic
2023-05-12 16:08:35 +02:00
Zoltan Papp
dcc83c8741 Replace the "recv a new msg from peer" debug log
to trace level
2023-05-12 11:07:41 +02:00
Zoltan Papp
d56669ec2e Remove unused dummy proxy 2023-05-12 11:05:59 +02:00
Misha Bragin
e3d2b6a408 Block user through HTTP API (#846)
The new functionality allows blocking a user in the Management service.
Blocked users lose access to the Dashboard, aren't able to modify the network map,
and all of their connected devices disconnect and are set to the "login expired" state.

Technically all above was achieved with the updated PUT /api/users endpoint,
that was extended with the is_blocked field.
2023-05-11 18:09:36 +02:00
Zoltan Papp
9f758b2015 Fix preshared key command line arg handling (#850) 2023-05-11 18:09:06 +02:00
Bethuel
2c50d7af1e Automatically load IdP OIDC configuration (#847) 2023-05-11 15:14:00 +02:00
pascal-fischer
e4c28f64fa Fix user cache lookup filtering for service users (#849) 2023-05-10 19:27:17 +02:00
Maycon Santos
6f2c4078ef Fix macOS installer script (#844)
Create /usr/local/bin/ folder before installation
2023-05-09 16:22:02 +02:00
Bethuel
f4ec1699ca Add Zitadel IdP (#833)
Added intergration with Zitadel management API.

Use the steps in zitadel.md for configuration.
2023-05-05 19:27:28 +02:00
Bethuel
fea53b2f0f Fix incomplete verification URI issue in device auth flow (#838)
Adds functionality to support Identity Provider (IdP) managers 
that do not support a complete verification URI in the 
device authentication flow. 
In cases where the verification_uri_complete field is empty,
the user will be prompted with their user_code, 
and the verification_uri  field will be used as a fallback
2023-05-05 12:43:04 +02:00
Zoltan Papp
60e6d0890a Fix sharedsock build on android (#837) 2023-05-05 10:55:23 +02:00
Misha Bragin
cb12e2da21 Correct sharedsock BPF fields (#835) 2023-05-04 12:28:32 +02:00
Bethuel
873b56f856 Add Azure Idp Manager (#822)
Added intergration with Azure IDP user API.

Use the steps in azure-ad.md for configuration:
cb03373f8f/docs/integrations/identity-providers/self-hosted/azure-ad.md
2023-05-03 14:51:44 +02:00
Maycon Santos
ecac82a5ae Share kernel Wireguard port with raw socket (#826)
This PR brings support of a shared port between stun (ICE agent) and
the kernel WireGuard

It implements a single port mode for execution with kernel WireGuard
interface using a raw socket listener.

BPF filters ensure that only STUN packets hit the NetBird userspace app

Removed a lot of the proxy logic and direct mode exchange.

Now we are doing an extra hole punch to the remote WireGuard 
port for best-effort cases and support to old client's direct mode.
2023-05-03 14:47:44 +02:00
pascal-fischer
59372ee159 API cleanup (#824)
removed all PATCH endpoints
updated path parameters for all endpoints
removed not implemented endpoints for api doc
minor description updates
2023-05-03 00:15:25 +02:00
pascal-fischer
08db5f5a42 Merge pull request #831 from netbirdio/fix/issue_with_account_creation_after_auth_refactor
FIx account creation issue after auth refactor
2023-05-02 19:14:54 +02:00
pascal-fischer
88678ef364 Merge pull request #808 from bcmmbaga/main
Add support for refreshing signing keys on expiry
2023-05-02 17:17:09 +02:00
Pascal Fischer
f1da4fd55d using old isAdmin function to create account 2023-05-02 16:49:29 +02:00
Bethuel
45224e76d0 fallback to olde keys if failing to fetch refreshed keys 2023-04-21 13:34:52 +03:00
Bethuel
90c8cfd863 synchronize access to the signing keys 2023-04-19 17:11:38 +03:00
Bethuel
f7196cd9a5 refactoring 2023-04-15 03:44:42 +03:00
Bethuel
53d78ad982 make variable unexported 2023-04-14 13:16:01 +03:00
Bethuel
9f352c1b7e validate keys for idp's with key rotation mechanism 2023-04-14 12:20:34 +03:00
Bethuel
a89808ecae initialize jwt validator with keys rotation state 2023-04-14 12:17:28 +03:00
Bethuel
c6190fa2ba add use-key-cache-headers flag to management command 2023-04-13 20:19:04 +03:00
67 changed files with 3777 additions and 2049 deletions

View File

@@ -72,6 +72,9 @@ jobs:
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
@@ -83,9 +86,13 @@ jobs:
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -3,6 +3,7 @@ package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/skratchdot/open-golang/open"
@@ -45,12 +46,16 @@ var loginCmd = &cobra.Command{
return err
}
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
PreSharedKey: &preSharedKey,
})
}
if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
@@ -106,7 +111,7 @@ var loginCmd = &cobra.Command{
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
@@ -185,7 +190,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request device code failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete)
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
@@ -199,11 +204,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil
}
func openURL(cmd *cobra.Command, verificationURIComplete string) {
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
err := open.Run(verificationURIComplete)
cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
" " + verificationURIComplete + " \n\n")
" " + verificationURIComplete + " " + codeMsg + " \n\n")
if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
}

View File

@@ -78,14 +78,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
PreSharedKey: &preSharedKey,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
})
}
if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
@@ -172,7 +176,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {

View File

@@ -3,6 +3,7 @@ package internal
import (
"context"
"fmt"
"io"
"math/rand"
"net"
"net/netip"
@@ -18,14 +19,15 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/sharedsock"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
@@ -99,10 +101,8 @@ type Engine struct {
wgInterface *iface.WGIface
udpMux ice.UDPMux
udpMuxSrflx ice.UniversalUDPMux
udpMuxConn *net.UDPConn
udpMuxConnSrflx *net.UDPConn
udpMux *bind.UniversalUDPMuxDefault
udpMuxConn io.Closer
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -206,33 +206,17 @@ func (e *Engine) Start() error {
e.close()
return err
}
e.udpMux = udpMux.UDPMuxDefault
e.udpMuxSrflx = udpMux
e.udpMux = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else {
networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter())
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
e.close()
return err
}
udpMuxParams := ice.UDPMuxParams{
UDPConn: e.udpMuxConn,
Net: transportNet,
}
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
e.close()
return err
}
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
mux := bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: rawSock, Net: transportNet})
go mux.ReadFromConn(e.ctx)
e.udpMuxConn = rawSock
e.udpMux = mux
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
@@ -262,7 +246,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p)
continue
}
@@ -394,9 +378,6 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
return err
}
// indicates message support in gRPC
msg.Body.FeaturesSupported = []uint32{signal.DirectCheck}
err = s.Send(msg)
if err != nil {
return err
@@ -776,9 +757,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
// we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock()
conf := conn.GetConf()
conf.StunTurn = append(e.STUNs, e.TURNs...)
conn.UpdateConf(conf)
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
e.syncMsgMux.Unlock()
err := conn.Open()
@@ -807,9 +786,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{
wgConfig := peer.WgConfig{
RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface,
AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey,
@@ -824,9 +803,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
Timeout: timeout,
UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx,
ProxyConfig: proxyConfig,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
@@ -918,18 +897,6 @@ func (e *Engine) receiveSignalEvents() {
}
conn.OnRemoteCandidate(candidate)
case sProto.Body_MODE:
protoMode := msg.GetBody().GetMode()
if protoMode == nil {
return fmt.Errorf("received an empty mode message")
}
err := conn.OnModeMessage(peer.ModeMessage{
Direct: protoMode.GetDirect(),
})
if err != nil {
log.Errorf("failed processing a mode message -> %s", err)
return err
}
}
return nil
@@ -1020,12 +987,6 @@ func (e *Engine) close() {
}
}
if e.udpMuxConnSrflx != nil {
if err := e.udpMuxConnSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux connection: %v", err)
}
}
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {

View File

@@ -148,6 +148,11 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}

View File

@@ -12,10 +12,11 @@ import (
"github.com/pion/ice/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/version"
@@ -27,8 +28,18 @@ const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
defaultWgKeepAlive = 25 * time.Second
)
type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
// ConnConfig is a peer Connection configuration
type ConnConfig struct {
@@ -47,7 +58,7 @@ type ConnConfig struct {
Timeout time.Duration
ProxyConfig proxy.Config
WgConfig WgConfig
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
@@ -102,7 +113,7 @@ type Conn struct {
statusRecorder *Status
proxy proxy.Proxy
proxy *WireGuardProxy
remoteModeCh chan ModeMessage
meta meta
@@ -126,9 +137,14 @@ func (conn *Conn) GetConf() ConnConfig {
return conn.config
}
// UpdateConf updates the connection config
func (conn *Conn) UpdateConf(conf ConnConfig) {
conn.config = conf
// WgConfig returns the WireGuard config
func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig
}
// UpdateStunTurn update the turn and stun addresses
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
conn.config.StunTurn = turnStun
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -239,12 +255,12 @@ func readICEAgentConfigProperties() (time.Duration, time.Duration) {
func (conn *Conn) Open() error {
log.Debugf("trying to connect to peer %s", conn.config.Key)
peerState := State{PubKey: conn.config.Key}
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
peerState.ConnStatusUpdate = time.Now()
peerState.ConnStatus = conn.status
peerState := State{
PubKey: conn.config.Key,
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -299,10 +315,11 @@ func (conn *Conn) Open() error {
defer conn.notifyDisconnected()
conn.mu.Unlock()
peerState = State{PubKey: conn.config.Key}
peerState.ConnStatus = conn.status
peerState.ConnStatusUpdate = time.Now()
peerState = State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -333,19 +350,12 @@ func (conn *Conn) Open() error {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// the ice connection has been established successfully so we are ready to start the proxy
err = conn.startProxy(remoteConn, remoteWgPort)
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
if err != nil {
return err
}
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
@@ -358,182 +368,81 @@ func (conn *Conn) Open() error {
}
}
// useProxy determines whether a direct connection (without a go proxy) is possible
//
// There are 3 cases:
//
// * When neither candidate is from hard nat and one of the peers has a public IP
//
// * both peers are in the same private network
//
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
//
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind {
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
return false
}
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
return false
}
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
return false
}
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
return false
}
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
return false
}
return true
}
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIP := net.ParseIP(pair.Local.Address())
remoteIP := net.ParseIP(pair.Remote.Address())
if localIP == nil || remoteIP == nil {
return false
}
// only consider /16 networks
mask := net.IPMask{255, 255, 0, 0}
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isHardNATCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
}
func isHostCandidateWithPublicIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && isPublicIP(candidate.Address())
}
func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
}
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
}
func isPublicIP(address string) bool {
ip := net.ParseIP(address)
if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
return false
}
return true
}
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
conn.mu.Lock()
defer conn.mu.Unlock()
var pair *ice.CandidatePair
pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil {
return err
return nil, err
}
peerState := State{PubKey: conn.config.Key}
p := conn.getProxyWithMessageExchange(pair, remoteWgPort)
conn.proxy = p
err = p.Start(remoteConn)
var endpoint net.Addr
if isRelayCandidate(pair.Local) {
conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
endpoint, err = conn.proxy.Start()
if err != nil {
conn.proxy = nil
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
if err != nil {
return err
if conn.proxy != nil {
_ = conn.proxy.Close()
}
return nil, err
}
conn.status = StatusConnected
peerState.ConnStatus = conn.status
peerState.ConnStatusUpdate = time.Now()
peerState.LocalIceCandidateType = pair.Local.Type().String()
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
Direct: conn.proxy == nil,
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
return nil
return endpoint, nil
}
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
localDirectMode := !useProxy
remoteDirectMode := localDirectMode
if conn.meta.protoSupport.DirectCheck {
go conn.sendLocalDirectMode(localDirectMode)
// will block until message received or timeout
remoteDirectMode = conn.receiveRemoteDirectMode()
}
if conn.config.UserspaceBind && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig)
}
if localDirectMode && remoteDirectMode {
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
}
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
}
func (conn *Conn) sendLocalDirectMode(localMode bool) {
// todo what happens when we couldn't deliver this message?
// we could retry, etc but there is no guarantee
err := conn.sendSignalMessage(&sProto.Message{
Key: conn.config.LocalKey,
RemoteKey: conn.config.Key,
Body: &sProto.Body{
Type: sProto.Body_MODE,
Mode: &sProto.Mode{
Direct: &localMode,
},
NetBirdVersion: version.NetbirdVersion(),
},
})
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
if err != nil {
log.Errorf("failed to send local proxy mode to remote peer %s, error: %s", conn.config.Key, err)
log.Warnf("got an error while resolving the udp address, err: %s", err)
return
}
}
func (conn *Conn) receiveRemoteDirectMode() bool {
timeout := time.Second
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case receivedMSG := <-conn.remoteModeCh:
return receivedMSG.Direct
case <-timer.C:
// we didn't receive a message from remote so we assume that it supports the direct mode to keep the old behaviour
log.Debugf("timeout after %s while waiting for remote direct mode message from remote peer %s",
timeout, conn.config.Key)
return true
mux, ok := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
log.Warn("invalid udp mux conversion")
return
}
_, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
if err != nil {
log.Warnf("got an error while sending the punch packet, err: %s", err)
}
}
@@ -543,20 +452,22 @@ func (conn *Conn) cleanup() error {
conn.mu.Lock()
defer conn.mu.Unlock()
var err1, err2, err3 error
if conn.agent != nil {
err := conn.agent.Close()
if err != nil {
return err
err1 = conn.agent.Close()
if err1 == nil {
conn.agent = nil
}
conn.agent = nil
}
// todo: is it problem if we try to remove a peer what is never existed?
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.proxy != nil {
err := conn.proxy.Close()
if err != nil {
return err
err3 = conn.proxy.Close()
if err3 != nil {
conn.proxy = nil
}
conn.proxy = nil
}
if conn.notifyDisconnected != nil {
@@ -566,10 +477,11 @@ func (conn *Conn) cleanup() error {
conn.status = StatusDisconnected
peerState := State{PubKey: conn.config.Key}
peerState.ConnStatus = conn.status
peerState.ConnStatusUpdate = time.Now()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available.
@@ -578,8 +490,13 @@ func (conn *Conn) cleanup() error {
}
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
return nil
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
@@ -757,16 +674,6 @@ func (conn *Conn) GetKey() string {
return conn.config.Key
}
// OnModeMessage unmarshall the payload message and send it to the mode message channel
func (conn *Conn) OnModeMessage(message ModeMessage) error {
select {
case conn.remoteModeCh <- message:
return nil
default:
return fmt.Errorf("unable to process mode message: channel busy")
}
}
// RegisterProtoSupportMeta register supported proto message in the connection metadata
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
protoSupport := signal.ParseFeaturesSupported(support)

View File

@@ -9,11 +9,9 @@ import (
"github.com/magiconair/properties/assert"
"github.com/pion/ice/v2"
"golang.org/x/sync/errgroup"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/iface"
sproto "github.com/netbirdio/netbird/signal/proto"
)
var connConf = ConnConfig{
@@ -170,310 +168,3 @@ func TestConn_Close(t *testing.T) {
wg.Wait()
}
type mockICECandidate struct {
ice.Candidate
AddressFunc func() string
TypeFunc func() ice.CandidateType
}
// Address mocks and overwrite ice.Candidate Address method
func (m *mockICECandidate) Address() string {
if m.AddressFunc != nil {
return m.AddressFunc()
}
return ""
}
// Type mocks and overwrite ice.Candidate Type method
func (m *mockICECandidate) Type() ice.CandidateType {
if m.TypeFunc != nil {
return m.TypeFunc()
}
return ice.CandidateTypeUnspecified
}
func TestConn_ShouldUseProxy(t *testing.T) {
publicHostCandidate := &mockICECandidate{
AddressFunc: func() string {
return "8.8.8.8"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
},
}
privateHostCandidate := &mockICECandidate{
AddressFunc: func() string {
return "10.0.0.1"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
},
}
srflxCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeServerReflexive
},
}
prflxCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
},
}
relayCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeRelay
},
}
testCases := []struct {
name string
candatePair *ice.CandidatePair
expected bool
}{
{
name: "Use Proxy When Local Candidate Is Relay",
candatePair: &ice.CandidatePair{
Local: relayCandidate,
Remote: privateHostCandidate,
},
expected: true,
},
{
name: "Use Proxy When Remote Candidate Is Relay",
candatePair: &ice.CandidatePair{
Local: privateHostCandidate,
Remote: relayCandidate,
},
expected: true,
},
{
name: "Use Proxy When Local Candidate Is Peer Reflexive",
candatePair: &ice.CandidatePair{
Local: prflxCandidate,
Remote: privateHostCandidate,
},
expected: true,
},
{
name: "Use Proxy When Remote Candidate Is Peer Reflexive",
candatePair: &ice.CandidatePair{
Local: privateHostCandidate,
Remote: prflxCandidate,
},
expected: true,
},
{
name: "Don't Use Proxy When Local Candidate Is Public And Remote Is Private",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: privateHostCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Remote Candidate Is Public And Local Is Private",
candatePair: &ice.CandidatePair{
Local: privateHostCandidate,
Remote: publicHostCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Local Candidate is Public And Remote Is Server Reflexive",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: srflxCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Remote Candidate is Public And Local Is Server Reflexive",
candatePair: &ice.CandidatePair{
Local: srflxCandidate,
Remote: publicHostCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Both Candidates Are Public",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: publicHostCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Both Candidates Are Private",
candatePair: &ice.CandidatePair{
Local: privateHostCandidate,
Remote: privateHostCandidate,
},
expected: false,
},
{
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: false,
},
{
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := shouldUseProxy(testCase.candatePair, false)
if result != testCase.expected {
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
}
})
}
}
func TestGetProxyWithMessageExchange(t *testing.T) {
publicHostCandidate := &mockICECandidate{
AddressFunc: func() string {
return "8.8.8.8"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
},
}
relayCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeRelay
},
}
testCases := []struct {
name string
candatePair *ice.CandidatePair
inputDirectModeSupport bool
inputRemoteModeMessage bool
expected proxy.Type
}{
{
name: "Should Result In Using Wireguard Proxy When Local Eval Is Use Proxy",
candatePair: &ice.CandidatePair{
Local: relayCandidate,
Remote: publicHostCandidate,
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: publicHostCandidate,
},
inputDirectModeSupport: true,
inputRemoteModeMessage: false,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
candatePair: &ice.CandidatePair{
Local: relayCandidate,
Remote: publicHostCandidate,
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: publicHostCandidate,
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeDirectNoProxy,
},
{
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
candatePair: &ice.CandidatePair{
Local: publicHostCandidate,
Remote: publicHostCandidate,
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeDirectNoProxy,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
g := errgroup.Group{}
conn, err := NewConn(connConf, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
conn.meta.protoSupport.DirectCheck = testCase.inputDirectModeSupport
conn.SetSendSignalMessage(func(message *sproto.Message) error {
return nil
})
g.Go(func() error {
return conn.OnModeMessage(ModeMessage{
Direct: testCase.inputRemoteModeMessage,
})
})
resultProxy := conn.getProxyWithMessageExchange(testCase.candatePair, 1000)
err = g.Wait()
if err != nil {
t.Error(err)
}
if resultProxy.Type() != testCase.expected {
t.Errorf("result didn't match expected value: Expected: %s, Got: %s", testCase.expected, resultProxy.Type())
}
})
}
}

View File

@@ -1,7 +1,8 @@
package proxy
package peer
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"net"
)
@@ -11,67 +12,45 @@ type WireGuardProxy struct {
ctx context.Context
cancel context.CancelFunc
config Config
wgListenPort int
remoteKey string
remoteConn net.Conn
localConn net.Conn
}
func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireGuardProxy{config: config}
func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
p := &WireGuardProxy{
wgListenPort: wgListenPort,
remoteKey: remoteKey,
remoteConn: remoteConn,
}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil {
return err
}
// add local proxy connection as a Wireguard peer
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
udpAddr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
func (p *WireGuardProxy) Start() (net.Addr, error) {
lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return err
}
err = p.updateEndpoint()
if err != nil {
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
return err
return nil, err
}
p.localConn = lConn
go p.proxyToRemote()
go p.proxyToLocal()
return nil
return lConn.LocalAddr(), nil
}
func (p *WireGuardProxy) Close() error {
p.cancel()
if c := p.localConn; c != nil {
if p.localConn != nil {
err := p.localConn.Close()
if err != nil {
return err
}
}
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
@@ -83,7 +62,7 @@ func (p *WireGuardProxy) proxyToRemote() {
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
return
default:
n, err := p.localConn.Read(buf)
@@ -107,7 +86,7 @@ func (p *WireGuardProxy) proxyToLocal() {
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
return
default:
n, err := p.remoteConn.Read(buf)
@@ -122,7 +101,3 @@ func (p *WireGuardProxy) proxyToLocal() {
}
}
}
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
}

View File

@@ -1,57 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
// Type returns the type of this proxy
func (p *DirectNoProxy) Type() Type {
return TypeDirectNoProxy
}

View File

@@ -1,72 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
"time"
)
// DummyProxy just sends pings to the RemoteKey peer and reads responses
type DummyProxy struct {
conn net.Conn
remote string
ctx context.Context
cancel context.CancelFunc
}
func NewDummyProxy(remote string) *DummyProxy {
p := &DummyProxy{remote: remote}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *DummyProxy) Close() error {
p.cancel()
return nil
}
func (p *DummyProxy) Start(remoteConn net.Conn) error {
p.conn = remoteConn
go func() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Read(buf)
if err != nil {
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
return
}
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
}
}
}()
go func() {
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Write([]byte("hello"))
//log.Debugf("sent ping to %s", p.remote)
if err != nil {
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
return
}
time.Sleep(5 * time.Second)
}
}
}()
return nil
}
func (p *DummyProxy) Type() Type {
return TypeDummy
}

View File

@@ -1,42 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
type NoProxy struct {
config Config
}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
}
func (p *NoProxy) Type() Type {
return TypeNoProxy
}

View File

@@ -1,35 +0,0 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
"net"
"time"
)
const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
)
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface {
io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn
Start(remoteConn net.Conn) error
Type() Type
}

8
go.mod
View File

@@ -38,12 +38,14 @@ require (
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/google/go-cmp v0.5.9
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/mattn/go-sqlite3 v1.14.16
github.com/mdlayher/socket v0.4.0
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/open-policy-agent/opa v0.49.0
@@ -93,14 +95,12 @@ require (
github.com/go-stack/stack v1.8.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
github.com/josharian/native v1.0.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mdlayher/genetlink v1.1.0 // indirect
github.com/mdlayher/netlink v1.4.2 // indirect
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb // indirect
github.com/mdlayher/netlink v1.7.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect

11
go.sum
View File

@@ -380,8 +380,9 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf
github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/josephspurrier/goversioninfo v0.0.0-20200309025242-14b0ab84c6ca/go.mod h1:eJTEwMjXb7kZ633hO3Ln9mBUCOjX2+FlTljvpl9SYdE=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 h1:uhL5Gw7BINiiPAo24A2sxkcDI0Jt/sqp1v5xQCniEFA=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
@@ -391,7 +392,6 @@ github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 h1:N527AHMa793TP5z5GNAn/VLPzlc0ewzWdeP/25gDfgQ=
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@@ -441,7 +441,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
github.com/mdlayher/ethtool v0.0.0-20211028163843-288d040e9d60 h1:tHdB+hQRHU10CfcK0furo6rSNgZ38JT8uPh70c/pFD8=
github.com/mdlayher/ethtool v0.0.0-20211028163843-288d040e9d60/go.mod h1:aYbhishWc4Ai3I2U4Gaa2n3kHWSwzme6EsG/46HRQbE=
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
github.com/mdlayher/genetlink v1.1.0 h1:k2YQT3959rJOF7gOvhdfQ0lut7QMIZiuVlJANheoZ+E=
@@ -456,12 +455,14 @@ github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnN
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
github.com/mdlayher/netlink v1.4.1/go.mod h1:e4/KuJ+s8UhfUpO9z00/fDZZmhSrs+oxyqAS9cNgn6Q=
github.com/mdlayher/netlink v1.4.2 h1:3sbnJWe/LETovA7yRZIX3f9McVOWV3OySH6iIBxiFfI=
github.com/mdlayher/netlink v1.4.2/go.mod h1:13VaingaArGUTUxFLf/iEovKxXji32JAtF858jZYEug=
github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg=
github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ=
github.com/mdlayher/socket v0.0.0-20210307095302-262dc9984e00/go.mod h1:GAFlyu4/XV68LkQKYzKhIo/WW7j3Zi0YRAz/BOoanUc=
github.com/mdlayher/socket v0.0.0-20211007213009-516dcbdf0267/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMVzFJ00qM25lqESg9Z4u3GuEXN5iHY=
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw=
github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc=
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=

View File

@@ -5,6 +5,7 @@ package bind
*/
import (
"context"
"fmt"
"net"
"time"
@@ -68,6 +69,39 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
return m
}
// ReadFromConn reads from the m.params.UDPConn provided upon the creation. It expects STUN packets only, however, will
// just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, 1500)
for {
select {
case <-ctx.Done():
log.Debugf("stopped reading from the UDPConn due to finished context")
return
default:
_, a, err := m.params.UDPConn.ReadFrom(buf)
if err != nil {
log.Errorf("error while reading packet %s", err)
continue
}
msg := &stun.Message{
Raw: buf,
}
err = msg.Decode()
if err != nil {
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
continue
}
err = m.HandleSTUNMessage(msg, a)
if err != nil {
log.Errorf("error while handling STUn message: %s", err)
}
}
}
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
@@ -75,6 +109,11 @@ type udpConn struct {
logger logging.LeveledLogger
}
// GetSharedConn returns the shared udp conn
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
return m.params.UDPConn
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}

View File

@@ -77,12 +77,17 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint net.Addr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
rAddr, err := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
if err != nil {
return err
}
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, rAddr, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface

View File

@@ -202,7 +202,7 @@ func (c *wGConfigurer) configure(config wgtypes.Config) error {
if err != nil {
return err
}
log.Debugf("got Wireguard device %s", c.deviceName)
log.Tracef("got Wireguard device %s", c.deviceName)
return wg.ConfigureDevice(c.deviceName, config)
}

View File

@@ -80,6 +80,7 @@ var (
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
@@ -186,6 +187,7 @@ var (
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
@@ -394,6 +396,12 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
}
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("configuring IdpManagerConfig.OIDCConfig.Issuer with a new value %s,", oidcConfig.Issuer)
config.IdpManagerConfig.OIDCConfig.Issuer = strings.TrimRight(oidcConfig.Issuer, "/")
log.Infof("configuring IdpManagerConfig.OIDCConfig.TokenEndpoint with a new value %s,", oidcConfig.TokenEndpoint)
config.IdpManagerConfig.OIDCConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
@@ -439,7 +447,7 @@ type OIDCConfigResponse struct {
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint)
if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
}
defer func() {

View File

@@ -16,13 +16,14 @@ const (
)
var (
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
idpSignKeyRefreshEnabled bool
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -54,6 +55,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -49,17 +49,16 @@ type AccountManager interface {
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, executingUserID string, targetUserID string) error
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserID(userID string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(userID string) (bool, error)
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
@@ -70,10 +69,10 @@ type AccountManager interface {
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
@@ -180,6 +179,7 @@ type UserInfo struct {
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -903,7 +903,9 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users {
users[user.Id] = struct{}{}
if !user.IsServiceUser {
users[user.Id] = struct{}{}
}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)

View File

@@ -91,6 +91,10 @@ const (
ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
// UserBlocked indicates that a user blocked another user
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
)
const (
@@ -184,6 +188,10 @@ const (
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked"
)
// Activity that triggered an Event
@@ -282,6 +290,10 @@ func (a Activity) Message() string {
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
case UserBlocked:
return UserBlockedMessage
case UserUnblocked:
return UserUnblockedMessage
default:
return "UNKNOWN_ACTIVITY"
}
@@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
return "user.join"
case UserInvited:
return "user.invite"
case UserBlocked:
return "user.block"
case UserUnblocked:
return "user.unblock"
case AccountCreated:
return "account.create"
case RuleAdded:

View File

@@ -80,6 +80,8 @@ type HttpServerConfig struct {
AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@@ -52,7 +52,9 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation)
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}

View File

@@ -59,13 +59,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
}
vars := mux.Vars(r)
accountID := vars["id"]
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return
}
var req api.PutApiAccountsIdJSONBody
var req api.PutApiAccountsAccountIdJSONBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)

View File

@@ -136,7 +136,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{id}", handler.UpdateAccount).Methods("PUT")
router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,8 +6,6 @@ info:
tags:
- name: Users
description: Interact with and view information about users.
- name: Tokens
description: Interact with and view information about tokens.
- name: Peers
description: Interact with and view information about peers.
- name: Setup Keys
@@ -67,7 +65,7 @@ components:
status:
description: User's status
type: string
enum: [ "active","invited","disabled" ]
enum: [ "active","invited","blocked" ]
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
@@ -81,6 +79,9 @@ components:
description: Is true if this user is a service user
type: boolean
readOnly: true
is_blocked:
description: Is true if this user is blocked. Blocked users can't use the system
type: boolean
required:
- id
- email
@@ -88,6 +89,7 @@ components:
- role
- auto_groups
- status
- is_blocked
UserRequest:
type: object
properties:
@@ -99,21 +101,25 @@ components:
type: array
items:
type: string
is_blocked:
description: If set to true then user is blocked and can't use the system
type: boolean
required:
- role
- auto_groups
- is_blocked
UserCreateRequest:
type: object
properties:
role:
description: User's NetBird account role
type: string
email:
description: User's Email to send invite to
type: string
name:
description: User's full name
type: string
role:
description: User's NetBird account role
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
@@ -343,6 +349,8 @@ components:
expires_in:
description: Expiration in days
type: integer
minimum: 1
maximum: 365
required:
- name
- expires_in
@@ -374,33 +382,6 @@ components:
$ref: '#/components/schemas/PeerMinimum'
required:
- peers
PatchMinimum:
type: object
properties:
op:
description: Patch operation type
type: string
enum: [ "replace","add","remove" ]
value:
description: Values to be applied
type: array
items:
type: string
required:
- op
- value
GroupPatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Group field to update in form /<field>
type: string
enum: [ "name","peers" ]
required:
- path
RuleMinimum:
type: object
properties:
@@ -446,17 +427,6 @@ components:
required:
- sources
- destinations
RulePatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Rule field to update in form /<field>
type: string
enum: [ "name","description","disabled","flow","sources","destinations" ]
required:
- path
PolicyRule:
type: object
properties:
@@ -585,17 +555,6 @@ components:
- id
- network_type
- $ref: '#/components/schemas/RouteRequest'
RoutePatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Route field to update in form /<field>
type: string
enum: [ "network","network_id","description","enabled","peer","metric","masquerade", "groups" ]
required:
- path
Nameserver:
type: object
properties:
@@ -667,17 +626,6 @@ components:
required:
- id
- $ref: '#/components/schemas/NameserverGroupRequest'
NameserverGroupPatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Nameserver group field to update in form /<field>
type: string
enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ]
required:
- path
DNSSettings:
type: object
properties:
@@ -705,7 +653,7 @@ components:
description: The string code of the activity that occurred during the event
type: string
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
"user.role.update",
"user.role.update", "user.block", "user.unblock",
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
"setupkey.group.delete", "setupkey.group.add",
"rule.add", "rule.delete", "rule.update",
@@ -761,15 +709,23 @@ components:
type: http
scheme: bearer
bearerFormat: JWT
TokenAuth:
type: apiKey
in: header
name: Authorization
description: >-
Enter the token with the `Token` prefix, e.g. "Token nbp_F3f0d.....".
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
paths:
/api/accounts:
get:
summary: Returns a list of accounts of a user. Always returns a list of one account. Only available for admin users.
summary: Returns a list of accounts of a user. Always returns a list of one account.
tags: [ Accounts ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON array of accounts
@@ -787,19 +743,20 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/accounts/{id}:
/api/accounts/{accountId}:
put:
summary: Update information about an account
tags: [ Accounts ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: accountId
required: true
schema:
type: string
description: The Account ID
description: The unique identifier of an account
requestBody:
description: update an account
content:
@@ -832,12 +789,13 @@ paths:
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: query
name: service_user
schema:
type: boolean
description: Filters users and returns either normal users or service users
description: Filters users and returns either regular users or service users
responses:
'200':
description: A JSON array of Users
@@ -855,12 +813,12 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/:
post:
summary: Create a User (invite)
summary: Create a User (or invite)
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: User invite information
content:
@@ -882,19 +840,20 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{id}:
/api/users/{userId}:
put:
summary: Update information about a User
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
requestBody:
description: User update
content:
@@ -923,11 +882,11 @@ paths:
- BearerAuth: [ ]
parameters:
- in: path
name: id
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
responses:
'200':
description: Delete status code
@@ -943,16 +902,17 @@ paths:
/api/users/{userId}/tokens:
get:
summary: Returns a list of all tokens for a user
tags: [ Tokens ]
tags: [ Users ]
security:
- BearerAuth: []
- TokenAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
responses:
'200':
description: A JSON Array of PersonalAccessTokens
@@ -971,17 +931,18 @@ paths:
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a new token
tags: [ Tokens ]
summary: Create a new token for a user
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
requestBody:
description: PersonalAccessToken create parameters
content:
@@ -1005,23 +966,24 @@ paths:
"$ref": "#/components/responses/internal_error"
/api/users/{userId}/tokens/{tokenId}:
get:
summary: Returns a specific token
tags: [ Tokens ]
summary: Returns a specific token for a user
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
- in: path
name: tokenId
required: true
schema:
type: string
description: The Token ID
description: The unique identifier of a token
responses:
'200':
description: A PersonalAccessTokens Object
@@ -1038,23 +1000,24 @@ paths:
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a token
tags: [ Tokens ]
summary: Delete a token for a user
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
description: The unique identifier of a user
- in: path
name: tokenId
required: true
schema:
type: string
description: The Token ID
description: The unique identifier of a token
responses:
'200':
description: Delete status code
@@ -1073,6 +1036,7 @@ paths:
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Peers
@@ -1090,19 +1054,20 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{id}:
/api/peers/{peerId}:
get:
summary: Get information about a peer
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: peerId
required: true
schema:
type: string
description: The Peer ID
description: The unique identifier of a peer
responses:
'200':
description: A Peer object
@@ -1123,13 +1088,14 @@ paths:
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: peerId
required: true
schema:
type: string
description: The Peer ID
description: The unique identifier of a peer
requestBody:
description: update a peer
content:
@@ -1167,13 +1133,14 @@ paths:
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: peerId
required: true
schema:
type: string
description: The Peer ID
description: The unique identifier of a peer
responses:
'200':
description: Delete status code
@@ -1192,6 +1159,7 @@ paths:
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Setup keys
@@ -1214,6 +1182,7 @@ paths:
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Setup Key request
content:
@@ -1235,19 +1204,20 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys/{id}:
/api/setup-keys/{keyId}:
get:
summary: Get information about a Setup Key
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: keyId
required: true
schema:
type: string
description: The Setup Key ID
description: The unique identifier of a setup key
responses:
'200':
description: A Setup Key object
@@ -1268,13 +1238,14 @@ paths:
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: keyId
required: true
schema:
type: string
description: The Setup Key ID
description: The unique identifier of a setup key
requestBody:
description: update to Setup Key
content:
@@ -1296,36 +1267,13 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Setup Key
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Setup Key ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/groups:
get:
summary: Returns a list of all Groups
tags: [ Groups ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Groups
@@ -1348,6 +1296,7 @@ paths:
tags: [ Groups ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Group request
content:
@@ -1378,19 +1327,20 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/groups/{id}:
/api/groups/{groupId}:
get:
summary: Get information about a Group
tags: [ Groups ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: groupId
required: true
schema:
type: string
description: The Group ID
description: The unique identifier of a group
responses:
'200':
description: A Group object
@@ -1411,13 +1361,14 @@ paths:
tags: [ Groups ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: groupId
required: true
schema:
type: string
description: The Group ID
description: The unique identifier of a group
requestBody:
description: Update Group request
content:
@@ -1446,53 +1397,19 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Group
tags: [ Groups ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Group ID
requestBody:
description: Update Group request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/GroupPatchOperation'
responses:
'200':
description: A Group object
content:
application/json:
schema:
$ref: '#/components/schemas/Group'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Group
tags: [ Groups ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: groupId
required: true
schema:
type: string
description: The Group ID
description: The unique identifier of a group
responses:
'200':
description: Delete status code
@@ -1511,6 +1428,7 @@ paths:
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Rules
@@ -1533,6 +1451,7 @@ paths:
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Rule request
content:
@@ -1557,19 +1476,20 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/Rule'
/api/rules/{id}:
/api/rules/{ruleId}:
get:
summary: Get information about a Rules
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: ruleId
required: true
schema:
type: string
description: The Rule ID
description: The unique identifier of a rule
responses:
'200':
description: A Rule object
@@ -1590,13 +1510,14 @@ paths:
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: ruleId
required: true
schema:
type: string
description: The Rule ID
description: The unique identifier of a rule
requestBody:
description: Update Rule request
content:
@@ -1634,13 +1555,14 @@ paths:
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: ruleId
required: true
schema:
type: string
description: The Rule ID
description: The unique identifier of a rule
responses:
'200':
description: Delete status code
@@ -1659,6 +1581,7 @@ paths:
tags: [ Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Policies
@@ -1681,6 +1604,7 @@ paths:
tags: [ Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Policy request
content:
@@ -1695,19 +1619,20 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/Policy'
/api/policies/{id}:
/api/policies/{policyId}:
get:
summary: Get information about a Policies
tags: [ Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: policyId
required: true
schema:
type: string
description: The Policy ID
description: The unique identifier of a policy
responses:
'200':
description: A Policy object
@@ -1728,13 +1653,14 @@ paths:
tags: [ Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: policyId
required: true
schema:
type: string
description: The Policy ID
description: The unique identifier of a policy
requestBody:
description: Update Policy request
content:
@@ -1762,13 +1688,14 @@ paths:
tags: [ Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: policyId
required: true
schema:
type: string
description: The Policy ID
description: The unique identifier of a policy
responses:
'200':
description: Delete status code
@@ -1787,6 +1714,7 @@ paths:
tags: [ Routes ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Routes
@@ -1809,6 +1737,7 @@ paths:
tags: [ Routes ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Routes request
content:
@@ -1831,19 +1760,20 @@ paths:
'500':
"$ref": "#/components/responses/internal_error"
/api/routes/{id}:
/api/routes/{routeId}:
get:
summary: Get information about a Routes
tags: [ Routes ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: routeId
required: true
schema:
type: string
description: The Route ID
description: The unique identifier of a route
responses:
'200':
description: A Route object
@@ -1864,13 +1794,14 @@ paths:
tags: [ Routes ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: routeId
required: true
schema:
type: string
description: The Route ID
description: The unique identifier of a route
requestBody:
description: Update Route request
content:
@@ -1892,53 +1823,19 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Route
tags: [ Routes ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Route ID
requestBody:
description: Update Route request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/RoutePatchOperation'
responses:
'200':
description: A Route object
content:
application/json:
schema:
$ref: '#/components/schemas/Route'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Route
tags: [ Routes ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: routeId
required: true
schema:
type: string
description: The Route ID
description: The unique identifier of a route
responses:
'200':
description: Delete status code
@@ -1957,6 +1854,7 @@ paths:
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Nameserver Groups
@@ -1979,6 +1877,7 @@ paths:
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Nameserver Groups request
content:
@@ -2001,19 +1900,20 @@ paths:
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers/{id}:
/api/dns/nameservers/{nsgroupId}:
get:
summary: Get information about a Nameserver Groups
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: nsgroupId
required: true
schema:
type: string
description: The Nameserver Group ID
description: The unique identifier of a Nameserver Group
responses:
'200':
description: A Nameserver Group object
@@ -2034,13 +1934,14 @@ paths:
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: nsgroupId
required: true
schema:
type: string
description: The Nameserver Group ID
description: The unique identifier of a Nameserver Group
requestBody:
description: Update Nameserver Group request
content:
@@ -2062,53 +1963,19 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
requestBody:
description: Update Nameserver Group request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/NameserverGroupPatchOperation'
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: id
name: nsgroupId
required: true
schema:
type: string
description: The Nameserver Group ID
description: The unique identifier of a Nameserver Group
responses:
'200':
description: Delete status code
@@ -2128,6 +1995,7 @@ paths:
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Object of DNS Setting
@@ -2149,6 +2017,7 @@ paths:
tags: [ DNS ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: A DNS settings object
content:
@@ -2176,6 +2045,7 @@ paths:
tags: [ Events ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Events

View File

@@ -9,6 +9,7 @@ import (
const (
BearerAuthScopes = "BearerAuth.Scopes"
TokenAuthScopes = "TokenAuth.Scopes"
)
// Defines values for EventActivityCode.
@@ -45,6 +46,7 @@ const (
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
EventActivityCodeUserBlock EventActivityCode = "user.block"
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
EventActivityCodeUserInvite EventActivityCode = "user.invite"
@@ -52,19 +54,7 @@ const (
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
)
// Defines values for GroupPatchOperationOp.
const (
GroupPatchOperationOpAdd GroupPatchOperationOp = "add"
GroupPatchOperationOpRemove GroupPatchOperationOp = "remove"
GroupPatchOperationOpReplace GroupPatchOperationOp = "replace"
)
// Defines values for GroupPatchOperationPath.
const (
GroupPatchOperationPathName GroupPatchOperationPath = "name"
GroupPatchOperationPathPeers GroupPatchOperationPath = "peers"
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
)
// Defines values for NameserverNsType.
@@ -72,61 +62,17 @@ const (
NameserverNsTypeUdp NameserverNsType = "udp"
)
// Defines values for NameserverGroupPatchOperationOp.
const (
NameserverGroupPatchOperationOpAdd NameserverGroupPatchOperationOp = "add"
NameserverGroupPatchOperationOpRemove NameserverGroupPatchOperationOp = "remove"
NameserverGroupPatchOperationOpReplace NameserverGroupPatchOperationOp = "replace"
)
// Defines values for NameserverGroupPatchOperationPath.
const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathDomains NameserverGroupPatchOperationPath = "domains"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
NameserverGroupPatchOperationPathPrimary NameserverGroupPatchOperationPath = "primary"
)
// Defines values for PatchMinimumOp.
const (
PatchMinimumOpAdd PatchMinimumOp = "add"
PatchMinimumOpRemove PatchMinimumOp = "remove"
PatchMinimumOpReplace PatchMinimumOp = "replace"
)
// Defines values for PolicyRuleAction.
const (
PolicyRuleActionAccept PolicyRuleAction = "accept"
PolicyRuleActionDrop PolicyRuleAction = "drop"
)
// Defines values for RoutePatchOperationOp.
const (
RoutePatchOperationOpAdd RoutePatchOperationOp = "add"
RoutePatchOperationOpRemove RoutePatchOperationOp = "remove"
RoutePatchOperationOpReplace RoutePatchOperationOp = "replace"
)
// Defines values for RoutePatchOperationPath.
const (
RoutePatchOperationPathDescription RoutePatchOperationPath = "description"
RoutePatchOperationPathEnabled RoutePatchOperationPath = "enabled"
RoutePatchOperationPathGroups RoutePatchOperationPath = "groups"
RoutePatchOperationPathMasquerade RoutePatchOperationPath = "masquerade"
RoutePatchOperationPathMetric RoutePatchOperationPath = "metric"
RoutePatchOperationPathNetwork RoutePatchOperationPath = "network"
RoutePatchOperationPathNetworkId RoutePatchOperationPath = "network_id"
RoutePatchOperationPathPeer RoutePatchOperationPath = "peer"
)
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserStatusActive UserStatus = "active"
UserStatusBlocked UserStatus = "blocked"
UserStatusInvited UserStatus = "invited"
)
// Account defines model for Account.
@@ -205,24 +151,6 @@ type GroupMinimum struct {
PeersCount int `json:"peers_count"`
}
// GroupPatchOperation defines model for GroupPatchOperation.
type GroupPatchOperation struct {
// Op Patch operation type
Op GroupPatchOperationOp `json:"op"`
// Path Group field to update in form /<field>
Path GroupPatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// GroupPatchOperationOp Patch operation type
type GroupPatchOperationOp string
// GroupPatchOperationPath Group field to update in form /<field>
type GroupPatchOperationPath string
// Nameserver defines model for Nameserver.
type Nameserver struct {
// Ip Nameserver IP
@@ -265,24 +193,6 @@ type NameserverGroup struct {
Primary bool `json:"primary"`
}
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
type NameserverGroupPatchOperation struct {
// Op Patch operation type
Op NameserverGroupPatchOperationOp `json:"op"`
// Path Nameserver group field to update in form /<field>
Path NameserverGroupPatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// NameserverGroupPatchOperationOp Patch operation type
type NameserverGroupPatchOperationOp string
// NameserverGroupPatchOperationPath Nameserver group field to update in form /<field>
type NameserverGroupPatchOperationPath string
// NameserverGroupRequest defines model for NameserverGroupRequest.
type NameserverGroupRequest struct {
// Description Nameserver group description
@@ -307,18 +217,6 @@ type NameserverGroupRequest struct {
Primary bool `json:"primary"`
}
// PatchMinimum defines model for PatchMinimum.
type PatchMinimum struct {
// Op Patch operation type
Op PatchMinimumOp `json:"op"`
// Value Values to be applied
Value []string `json:"value"`
}
// PatchMinimumOp Patch operation type
type PatchMinimumOp string
// Peer defines model for Peer.
type Peer struct {
// Connected Peer to Management connection status
@@ -516,24 +414,6 @@ type Route struct {
Peer string `json:"peer"`
}
// RoutePatchOperation defines model for RoutePatchOperation.
type RoutePatchOperation struct {
// Op Patch operation type
Op RoutePatchOperationOp `json:"op"`
// Path Route field to update in form /<field>
Path RoutePatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// RoutePatchOperationOp Patch operation type
type RoutePatchOperationOp string
// RoutePatchOperationPath Route field to update in form /<field>
type RoutePatchOperationPath string
// RouteRequest defines model for RouteRequest.
type RouteRequest struct {
// Description Route description
@@ -674,6 +554,9 @@ type User struct {
// Id User ID
Id string `json:"id"`
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
IsBlocked bool `json:"is_blocked"`
// IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"`
@@ -716,35 +599,32 @@ type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// IsBlocked If set to true then user is blocked and can't use the system
IsBlocked bool `json:"is_blocked"`
// Role User's NetBird account role
Role string `json:"role"`
}
// PutApiAccountsIdJSONBody defines parameters for PutApiAccountsId.
type PutApiAccountsIdJSONBody struct {
// PutApiAccountsAccountIdJSONBody defines parameters for PutApiAccountsAccountId.
type PutApiAccountsAccountIdJSONBody struct {
Settings AccountSettings `json:"settings"`
}
// PatchApiDnsNameserversIdJSONBody defines parameters for PatchApiDnsNameserversId.
type PatchApiDnsNameserversIdJSONBody = []NameserverGroupPatchOperation
// PostApiGroupsJSONBody defines parameters for PostApiGroups.
type PostApiGroupsJSONBody struct {
Name string `json:"name"`
Peers *[]string `json:"peers,omitempty"`
}
// PatchApiGroupsIdJSONBody defines parameters for PatchApiGroupsId.
type PatchApiGroupsIdJSONBody = []GroupPatchOperation
// PutApiGroupsIdJSONBody defines parameters for PutApiGroupsId.
type PutApiGroupsIdJSONBody struct {
// PutApiGroupsGroupIdJSONBody defines parameters for PutApiGroupsGroupId.
type PutApiGroupsGroupIdJSONBody struct {
Name *string `json:"Name,omitempty"`
Peers *[]string `json:"Peers,omitempty"`
}
// PutApiPeersIdJSONBody defines parameters for PutApiPeersId.
type PutApiPeersIdJSONBody struct {
// PutApiPeersPeerIdJSONBody defines parameters for PutApiPeersPeerId.
type PutApiPeersPeerIdJSONBody struct {
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
@@ -753,11 +633,8 @@ type PutApiPeersIdJSONBody struct {
// PostApiPoliciesJSONBody defines parameters for PostApiPolicies.
type PostApiPoliciesJSONBody = PolicyMinimum
// PutApiPoliciesIdJSONBody defines parameters for PutApiPoliciesId.
type PutApiPoliciesIdJSONBody = PolicyMinimum
// PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId.
type PatchApiRoutesIdJSONBody = []RoutePatchOperation
// PutApiPoliciesPolicyIdJSONBody defines parameters for PutApiPoliciesPolicyId.
type PutApiPoliciesPolicyIdJSONBody = PolicyMinimum
// PostApiRulesJSONBody defines parameters for PostApiRules.
type PostApiRulesJSONBody struct {
@@ -776,8 +653,8 @@ type PostApiRulesJSONBody struct {
Sources *[]string `json:"sources,omitempty"`
}
// PutApiRulesIdJSONBody defines parameters for PutApiRulesId.
type PutApiRulesIdJSONBody struct {
// PutApiRulesRuleIdJSONBody defines parameters for PutApiRulesRuleId.
type PutApiRulesRuleIdJSONBody struct {
// Description Rule friendly description
Description string `json:"description"`
Destinations *[]string `json:"destinations,omitempty"`
@@ -795,21 +672,18 @@ type PutApiRulesIdJSONBody struct {
// GetApiUsersParams defines parameters for GetApiUsers.
type GetApiUsersParams struct {
// ServiceUser Filters users and returns either normal users or service users
// ServiceUser Filters users and returns either regular users or service users
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
}
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody
// PutApiAccountsAccountIdJSONRequestBody defines body for PutApiAccountsAccountId for application/json ContentType.
type PutApiAccountsAccountIdJSONRequestBody PutApiAccountsAccountIdJSONBody
// PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType.
type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest
// PatchApiDnsNameserversIdJSONRequestBody defines body for PatchApiDnsNameserversId for application/json ContentType.
type PatchApiDnsNameserversIdJSONRequestBody = PatchApiDnsNameserversIdJSONBody
// PutApiDnsNameserversIdJSONRequestBody defines body for PutApiDnsNameserversId for application/json ContentType.
type PutApiDnsNameserversIdJSONRequestBody = NameserverGroupRequest
// PutApiDnsNameserversNsgroupIdJSONRequestBody defines body for PutApiDnsNameserversNsgroupId for application/json ContentType.
type PutApiDnsNameserversNsgroupIdJSONRequestBody = NameserverGroupRequest
// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType.
type PutApiDnsSettingsJSONRequestBody = DNSSettings
@@ -817,47 +691,41 @@ type PutApiDnsSettingsJSONRequestBody = DNSSettings
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
// PatchApiGroupsIdJSONRequestBody defines body for PatchApiGroupsId for application/json ContentType.
type PatchApiGroupsIdJSONRequestBody = PatchApiGroupsIdJSONBody
// PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType.
type PutApiGroupsGroupIdJSONRequestBody PutApiGroupsGroupIdJSONBody
// PutApiGroupsIdJSONRequestBody defines body for PutApiGroupsId for application/json ContentType.
type PutApiGroupsIdJSONRequestBody PutApiGroupsIdJSONBody
// PutApiPeersIdJSONRequestBody defines body for PutApiPeersId for application/json ContentType.
type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody
// PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType.
type PutApiPeersPeerIdJSONRequestBody PutApiPeersPeerIdJSONBody
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
type PostApiPoliciesJSONRequestBody = PostApiPoliciesJSONBody
// PutApiPoliciesIdJSONRequestBody defines body for PutApiPoliciesId for application/json ContentType.
type PutApiPoliciesIdJSONRequestBody = PutApiPoliciesIdJSONBody
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PutApiPoliciesPolicyIdJSONBody
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
type PostApiRoutesJSONRequestBody = RouteRequest
// PatchApiRoutesIdJSONRequestBody defines body for PatchApiRoutesId for application/json ContentType.
type PatchApiRoutesIdJSONRequestBody = PatchApiRoutesIdJSONBody
// PutApiRoutesIdJSONRequestBody defines body for PutApiRoutesId for application/json ContentType.
type PutApiRoutesIdJSONRequestBody = RouteRequest
// PutApiRoutesRouteIdJSONRequestBody defines body for PutApiRoutesRouteId for application/json ContentType.
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
type PostApiRulesJSONRequestBody PostApiRulesJSONBody
// PutApiRulesIdJSONRequestBody defines body for PutApiRulesId for application/json ContentType.
type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody
// PutApiRulesRuleIdJSONRequestBody defines body for PutApiRulesRuleId for application/json ContentType.
type PutApiRulesRuleIdJSONRequestBody PutApiRulesRuleIdJSONBody
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
type PostApiUsersJSONRequestBody = UserCreateRequest
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
type PutApiUsersIdJSONRequestBody = UserRequest
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
type PutApiUsersUserIdJSONRequestBody = UserRequest
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest

View File

@@ -62,7 +62,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
groupID, ok := vars["id"]
groupID, ok := vars["groupId"]
if !ok {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return
@@ -88,7 +88,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
var req api.PutApiGroupsIdJSONRequestBody
var req api.PutApiGroupsGroupIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -121,110 +121,6 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, toGroupResponse(account, &group))
}
// PatchGroup handles patch updates to a group identified by a given ID
func (h *GroupsHandler) PatchGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
groupID := vars["id"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
_, ok := account.Groups[groupID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
var req api.PatchApiGroupsIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
var operations []server.GroupUpdateOperation
for _, patch := range req {
switch patch.Path {
case api.GroupPatchOperationPathName:
if patch.Op != api.GroupPatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"name field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) == 0 || patch.Value[0] == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
operations = append(operations, server.GroupUpdateOperation{
Type: server.UpdateGroupName,
Values: patch.Value,
})
case api.GroupPatchOperationPathPeers:
switch patch.Op {
case api.GroupPatchOperationOpReplace:
peerKeys := peerIPsToKeys(account, &patch.Value)
operations = append(operations, server.GroupUpdateOperation{
Type: server.UpdateGroupPeers,
Values: peerKeys,
})
case api.GroupPatchOperationOpRemove:
peerKeys := peerIPsToKeys(account, &patch.Value)
operations = append(operations, server.GroupUpdateOperation{
Type: server.RemovePeersFromGroup,
Values: peerKeys,
})
case api.GroupPatchOperationOpAdd:
operations = append(operations, server.GroupUpdateOperation{
Type: server.InsertPeersToGroup,
Values: patch.Value,
})
default:
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation, \"%v\", for PeersHandler field", patch.Op), w)
return
}
default:
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
@@ -277,7 +173,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
}
aID := account.Id
groupID := mux.Vars(r)["id"]
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
@@ -314,7 +210,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
groupID := mux.Vars(r)["id"]
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
@@ -335,29 +231,6 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
}
}
func peerIPsToKeys(account *server.Account, peerIPs *[]string) []string {
var mappedPeerKeys []string
if peerIPs == nil {
return mappedPeerKeys
}
peersChecked := make(map[string]struct{})
for _, requestPeersIP := range *peerIPs {
_, ok := peersChecked[requestPeersIP]
if ok {
continue
}
peersChecked[requestPeersIP] = struct{}{}
for _, accountPeer := range account.Peers {
if accountPeer.IP.String() == requestPeersIP {
mappedPeerKeys = append(mappedPeerKeys, accountPeer.Key)
}
}
}
return mappedPeerKeys
}
func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
cache := make(map[string]api.PeerMinimum)
gr := api.Group{

View File

@@ -136,7 +136,7 @@ func TestGetGroup(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/groups/{id}", p.GetGroup).Methods("GET")
router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -230,53 +230,6 @@ func TestWriteGroup(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Group PATCH Name OK",
requestType: http.MethodPatch,
requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":["Default POSTed Group"]}]`)),
expectedStatus: http.StatusOK,
expectedGroup: &api.Group{
Id: "id-existed",
Name: "Default POSTed Group",
},
},
{
name: "Write Group PATCH Invalid Name OP",
requestType: http.MethodPatch,
requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Group PATCH Invalid Name",
requestType: http.MethodPatch,
requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":[]}]`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Group PATCH PeersHandler OK",
requestType: http.MethodPatch,
requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"peers","value":["100.100.100.100","200.200.200.200"]}]`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedGroup: &api.Group{
Id: "id-existed",
PeersCount: 2,
Peers: []api.PeerMinimum{
{Id: "peer-A-ID"},
{Id: "peer-B-ID"},
},
},
},
}
adminUser := server.NewAdminUser("test_user")
@@ -289,8 +242,7 @@ func TestWriteGroup(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST")
router.HandleFunc("/api/groups/{id}", p.UpdateGroup).Methods("PUT")
router.HandleFunc("/api/groups/{id}", p.PatchGroup).Methods("PATCH")
router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.IsUserAdmin)
accountManager.GetUser)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
@@ -96,22 +96,22 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
func (apiHandler *apiHandler) addAccountsEndpoint() {
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/accounts/{id}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPeersEndpoint() {
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
}
@@ -127,56 +127,53 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{id}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{id}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
}
func (apiHandler *apiHandler) addRulesEndpoint() {
rulesHandler := NewRulesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addPoliciesEndpoint() {
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addGroupsEndpoint() {
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.PatchGroup).Methods("PATCH", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{id}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addRoutesEndpoint() {
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.PatchRoute).Methods("PATCH", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{id}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroup).Methods("PATCH", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSSettingEndpoint() {

View File

@@ -6,28 +6,30 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
type IsUserAdminFunc func(userID string) (bool, error)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
isUserAdmin IsUserAdminFunc
claimsExtract jwtclaims.ClaimsExtractor
getUser GetUser
}
// NewAccessControl instance constructor
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
return &AccessControl{
isUserAdmin: isUserAdmin,
claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
getUser: getUser,
}
}
@@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(claims.UserId)
user, err := a.getUser(claims)
if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if !ok {
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
if !user.IsAdmin() {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil {
log.Debugf("Regex failed")
log.Debugf("regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w)
return
}
if ok {
log.Debugf("Valid Path")
log.Debugf("valid Path")
h.ServeHTTP(w, r)
return
}

View File

@@ -99,13 +99,13 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroupID := mux.Vars(r)["id"]
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
var req api.PutApiDnsNameserversIdJSONRequestBody
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -140,88 +140,6 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
util.WriteJSONObject(w, &resp)
}
// PatchNameserverGroup handles patch updates to a nameserver group identified by a given ID
func (h *NameserversHandler) PatchNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
var req api.PatchApiDnsNameserversIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
var operations []server.NameServerGroupUpdateOperation
for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"nameserver groups only accepts replace operations, got %s", patch.Op), w)
return
}
switch patch.Path {
case api.NameserverGroupPatchOperationPathName:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupName,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDescription:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDescription,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathPrimary:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupPrimary,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDomains:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDomains,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathGroups:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupGroups,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathEnabled:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupEnabled,
Values: patch.Value,
})
default:
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, user.Id, operations)
if err != nil {
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
util.WriteJSONObject(w, &resp)
}
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
@@ -231,7 +149,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroupID := mux.Vars(r)["id"]
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
@@ -256,7 +174,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return
}
nsGroupID := mux.Vars(r)["id"]
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return

View File

@@ -227,31 +227,6 @@ func TestNameserversHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PATCH OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: baseExistingNSGroup.Name,
Description: "NewDesc",
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled,
Primary: baseExistingNSGroup.Primary,
},
},
{
name: "PATCH Invalid Nameserver Group OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + notFoundRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
p := initNameserversTestData()
@@ -262,11 +237,10 @@ func TestNameserversHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{id}", p.GetNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{id}", p.DeleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{id}", p.UpdateNameserverGroup).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{id}", p.PatchNameserverGroup).Methods("PATCH")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -63,7 +63,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler {
return &PATHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
},
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
}
return nil
},
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}

View File

@@ -42,7 +42,7 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
}
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiPeersIdJSONBody{}
req := &api.PutApiPeersPeerIdJSONBody{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -78,7 +78,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
peerID := vars["id"]
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return

View File

@@ -146,8 +146,8 @@ func TestGetPeers(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{id}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{id}", p.HandlePeer).Methods("PUT")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -60,7 +60,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
policyID := vars["id"]
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
@@ -78,7 +78,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
var req api.PutApiPoliciesIdJSONRequestBody
var req api.PutApiPoliciesPolicyIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -214,7 +214,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
aID := account.Id
vars := mux.Vars(r)
policyID := vars["id"]
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
@@ -240,7 +240,7 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
vars := mux.Vars(r)
policyID := vars["id"]
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return

View File

@@ -103,7 +103,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
routeID := vars["id"]
routeID := vars["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
@@ -115,7 +115,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
var req api.PutApiRoutesIdJSONRequestBody
var req api.PutApiRoutesRouteIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -159,147 +159,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, &resp)
}
// PatchRoute handles patch updates to a route identified by a given ID
func (h *RoutesHandler) PatchRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
routeID := vars["id"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
var req api.PatchApiRoutesIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
var operations []server.RouteUpdateOperation
for _, patch := range req {
switch patch.Path {
case api.RoutePatchOperationPathNetwork:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"network field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteNetwork,
Values: patch.Value,
})
case api.RoutePatchOperationPathDescription:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"description field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteDescription,
Values: patch.Value,
})
case api.RoutePatchOperationPathNetworkId:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"network Identifier field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteNetworkIdentifier,
Values: patch.Value,
})
case api.RoutePatchOperationPathPeer:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"peer field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) > 1 {
util.WriteError(status.Errorf(status.InvalidArgument,
"value field only accepts 1 value, got %d", len(patch.Value)), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRoutePeer,
Values: patch.Value,
})
case api.RoutePatchOperationPathMetric:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"metric field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteMetric,
Values: patch.Value,
})
case api.RoutePatchOperationPathMasquerade:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"masquerade field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteMasquerade,
Values: patch.Value,
})
case api.RoutePatchOperationPathEnabled:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"enabled field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteEnabled,
Values: patch.Value,
})
case api.RoutePatchOperationPathGroups:
if patch.Op != api.RoutePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"groups field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
Type: server.UpdateRouteGroups,
Values: patch.Value,
})
default:
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
root, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRouteResponse(root)
util.WriteJSONObject(w, &resp)
}
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
@@ -309,7 +168,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return
}
routeID := mux.Vars(r)["id"]
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
@@ -333,7 +192,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
routeID := mux.Vars(r)["id"]
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return

View File

@@ -288,61 +288,6 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PATCH Description OK",
requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "NewDesc",
NetworkId: "awesomeNet",
Network: baseExistingRoute.Network.String(),
NetworkType: route.IPv4NetworkString,
Masquerade: baseExistingRoute.Masquerade,
Enabled: baseExistingRoute.Enabled,
Metric: baseExistingRoute.Metric,
Groups: baseExistingRoute.Groups,
},
},
{
name: "PATCH Peer OK",
requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", existingPeerID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "NewDesc",
NetworkId: "awesomeNet",
Network: baseExistingRoute.Network.String(),
NetworkType: route.IPv4NetworkString,
Peer: existingPeerID,
Masquerade: baseExistingRoute.Masquerade,
Enabled: baseExistingRoute.Enabled,
Metric: baseExistingRoute.Metric,
Groups: baseExistingRoute.Groups,
},
},
{
name: "PATCH Not Found Peer",
requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PATCH Not Found Route",
requestType: http.MethodPatch,
requestPath: "/api/routes/" + notFoundRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"network\",\"value\":[\"192.168.0.0/34\"]}]"),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
p := initRoutesTestData()
@@ -353,11 +298,10 @@ func TestRoutesHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/routes/{id}", p.GetRoute).Methods("GET")
router.HandleFunc("/api/routes/{id}", p.DeleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST")
router.HandleFunc("/api/routes/{id}", p.UpdateRoute).Methods("PUT")
router.HandleFunc("/api/routes/{id}", p.PatchRoute).Methods("PATCH")
router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -65,7 +65,7 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
ruleID := vars["id"]
ruleID := vars["ruleId"]
if len(ruleID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
@@ -77,7 +77,7 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
return
}
var req api.PutApiRulesIdJSONRequestBody
var req api.PutApiRulesRuleIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -210,7 +210,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
}
aID := account.Id
rID := mux.Vars(r)["id"]
rID := mux.Vars(r)["ruleId"]
if len(rID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
@@ -236,7 +236,7 @@ func (h *RulesHandler) GetRule(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
ruleID := mux.Vars(r)["id"]
ruleID := mux.Vars(r)["ruleId"]
if len(ruleID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return

View File

@@ -133,7 +133,7 @@ func TestRulesGetRule(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules/{id}", p.GetRule).Methods("GET")
router.HandleFunc("/api/rules/{ruleId}", p.GetRule).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -235,7 +235,7 @@ func TestRulesWriteRule(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/rules", p.CreateRule).Methods("POST")
router.HandleFunc("/api/rules/{id}", p.UpdateRule).Methods("PUT")
router.HandleFunc("/api/rules/{ruleId}", p.UpdateRule).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -84,7 +84,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
keyID := vars["id"]
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
@@ -109,13 +109,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
}
vars := mux.Vars(r)
keyID := vars["id"]
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
req := &api.PutApiSetupKeysIdJSONRequestBody{}
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)

View File

@@ -174,8 +174,8 @@ func TestSetupKeysHandlers(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -48,19 +48,24 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
userID := vars["id"]
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
req := &api.PutApiUsersIdJSONRequestBody{}
req := &api.PutApiUsersUserIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.AutoGroups == nil {
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
@@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
})
if err != nil {
util.WriteError(err, w)
return
@@ -94,7 +101,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
targetUserID := vars["id"]
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
@@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
case "invited":
userStatus = api.UserStatusInvited
default:
userStatus = api.UserStatusDisabled
userStatus = api.UserStatusBlocked
}
if user.IsBlocked {
userStatus = api.UserStatusBlocked
}
isCurrent := user.ID == currenUserID
@@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
Status: userStatus,
IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
}
}

View File

@@ -3,6 +3,7 @@ package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
AutoGroups: []string{"group_1"},
},
},
}
@@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
}
return key, nil
},
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
@@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
}
return nil
},
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil)
if err != nil {
return nil, err
}
return info, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
}
}
func TestUpdateUser(t *testing.T) {
tt := []struct {
name string
expectedStatusCode int
requestType string
requestPath string
requestBody io.Reader
expectedUserID string
expectedRole string
expectedStatus string
expectedBlocked bool
expectedIsServiceUser bool
expectedGroups []string
}{
{
name: "Update_Block_User",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: true,
expectedRole: "user",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
},
{
name: "Update_Change_Role_To_Admin",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Update_Groups",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Should_Fail_Because_AutoGroups_Is_Absent",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusBadRequest,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatusCode {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if tc.expectedStatusCode == 200 {
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
respBody := &api.User{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("response content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedUserID, respBody.Id)
assert.Equal(t, tc.expectedRole, respBody.Role)
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
for _, expectedGroup := range tc.expectedGroups {
exists := false
for _, actualGroup := range respBody.AutoGroups {
if expectedGroup == actualGroup {
exists = true
}
}
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
}
}
})
}
}
func TestCreateUser(t *testing.T) {
name := "name"
email := "email"
@@ -219,21 +354,21 @@ func TestDeleteUser(t *testing.T) {
name: "Delete Regular User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + regularUserID,
requestVars: map[string]string{"id": regularUserID},
requestVars: map[string]string{"userId": regularUserID},
expectedStatus: http.StatusForbidden,
},
{
name: "Delete Service User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + serviceUserID,
requestVars: map[string]string{"id": serviceUserID},
requestVars: map[string]string{"userId": serviceUserID},
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + notFoundUserID,
requestVars: map[string]string{"id": notFoundUserID},
requestVars: map[string]string{"userId": notFoundUserID},
expectedStatus: http.StatusNotFound,
},
}

View File

@@ -32,10 +32,10 @@ type Auth0Manager struct {
// Auth0ClientConfig auth0 manager client configurations
type Auth0ClientConfig struct {
Audience string
AuthIssuer string
AuthIssuer string `json:"-"`
ClientID string
ClientSecret string
GrantType string
GrantType string `json:"-"`
}
// auth0JWTRequest payload struct to request a JWT Token
@@ -110,7 +110,8 @@ type auth0Profile struct {
}
// NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
func NewAuth0Manager(oidcConfig OIDCConfig, config Auth0ClientConfig,
appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -121,17 +122,19 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
}
helper := JsonParser{}
config.AuthIssuer = oidcConfig.TokenEndpoint
config.GrantType = "client_credentials"
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.Audience == "" || config.AuthIssuer == "" {
return nil, fmt.Errorf("auth0 idp configuration is not complete")
if config.ClientID == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, clientID is missing")
}
if config.GrantType != "client_credentials" {
return nil, fmt.Errorf("auth0 idp configuration failed. Grant Type should be client_credentials")
if config.ClientSecret == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
}
if !strings.HasPrefix(strings.ToLower(config.AuthIssuer), "https://") {
return nil, fmt.Errorf("auth0 idp configuration failed. AuthIssuer should contain https://")
if config.Audience == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
}
credentials := &Auth0Credentials{

View File

@@ -459,26 +459,9 @@ func TestNewAuth0Manager(t *testing.T) {
testCase3Config := defaultTestConfig
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
testCase3 := test{
name: "Wrong Auth Issuer Format",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong auth issuer format",
}
testCase4Config := defaultTestConfig
testCase4Config.GrantType = "spa"
testCase4 := test{
name: "Wrong Grant Type",
inputConfig: testCase4Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong grant type",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
for _, testCase := range []test{testCase1, testCase2} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
_, err := NewAuth0Manager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}

View File

@@ -0,0 +1,671 @@
package idp
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
)
const (
// azure extension properties template
wtAccountIDTpl = "extension_%s_wt_account_id"
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
profileFields = "id,displayName,mail,userPrincipalName"
extensionFields = "id,name,targetObjects"
)
// AzureManager azure manager client instance.
type AzureManager struct {
ClientID string
ObjectID string
GraphAPIEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// AzureClientConfig azure manager client configurations.
type AzureClientConfig struct {
ClientID string
ClientSecret string
ObjectID string
GraphAPIEndpoint string `json:"-"`
TokenEndpoint string `json:"-"`
GrantType string `json:"-"`
}
// AzureCredentials azure authentication information.
type AzureCredentials struct {
clientConfig AzureClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// azureProfile represents an azure user profile.
type azureProfile map[string]any
// passwordProfile represent authentication method for,
// newly created user profile.
type passwordProfile struct {
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
Password string `json:"password"`
}
// azureExtension represent custom attribute,
// that can be added to user objects in Azure Active Directory (AD).
type azureExtension struct {
Name string `json:"name"`
DataType string `json:"dataType"`
TargetObjects []string `json:"targetObjects"`
}
// NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(oidcConfig OIDCConfig, config AzureClientConfig,
appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.GraphAPIEndpoint = "https://graph.microsoft.com"
config.GrantType = "client_credentials"
if config.ClientID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
}
if config.ObjectID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
}
credentials := &AzureCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
manager := &AzureManager{
ObjectID: config.ObjectID,
ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}
err := manager.configureAppMetadata()
if err != nil {
return nil, err
}
return manager, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
func (ac *AzureCredentials) jwtStillValid() bool {
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
data.Set("grant_type", ac.clientConfig.GrantType)
data.Set("scope", "https://graph.microsoft.com/.default")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = ac.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = ac.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if ac.jwtStillValid() {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken()
if err != nil {
return ac.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
if err != nil {
return ac.jwtToken, err
}
ac.jwtToken = jwtToken
return ac.jwtToken, nil
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
if err != nil {
return nil, err
}
body, err := am.post("users", payload)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
profile[wtAccountIDField] = accountID
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
profile[wtPendingInviteField] = true
return profile.userData(am.ClientID), nil
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users/"+userID, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
return profile.userData(am.ClientID), nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users/"+email, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, profile.userData(am.ClientID))
return users, nil
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Value {
users = append(users, profile.userData(am.ClientID))
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value {
userData := profile.userData(am.ClientID)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
data, err := am.helper.Marshal(map[string]any{
wtAccountIDField: appMetadata.WTAccountID,
wtPendingInviteField: appMetadata.WTPendingInvite,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating idp metadata for user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.get(resource, q)
if err != nil {
return nil, err
}
var extensions struct{ Value []azureExtension }
err = am.helper.Unmarshal(body, &extensions)
if err != nil {
return nil, err
}
return extensions.Value, nil
}
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
extension := azureExtension{
Name: name,
DataType: "string",
TargetObjects: []string{"User"},
}
payload, err := am.helper.Marshal(extension)
if err != nil {
return nil, err
}
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.post(resource, string(payload))
if err != nil {
return nil, err
}
var userExtension azureExtension
err = am.helper.Unmarshal(body, &userExtension)
if err != nil {
return nil, err
}
return &userExtension, nil
}
// configureAppMetadata sets up app metadata extensions if they do not exists.
func (am *AzureManager) configureAppMetadata() error {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
extensions, err := am.getUserExtensions()
if err != nil {
return err
}
// If the wt_account_id extension does not already exist, create it.
if !hasExtension(extensions, wtAccountIDField) {
_, err = am.createUserExtension(wtAccountID)
if err != nil {
return err
}
}
// If the wt_pending_invite extension does not already exist, create it.
if !hasExtension(extensions, wtPendingInviteField) {
_, err = am.createUserExtension(wtPendingInvite)
if err != nil {
return err
}
}
return nil
}
// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// post perform Post requests.
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile.
func (ap azureProfile) userData(clientID string) *UserData {
id, ok := ap["id"].(string)
if !ok {
id = ""
}
email, ok := ap["userPrincipalName"].(string)
if !ok {
email = ""
}
name, ok := ap["displayName"].(string)
if !ok {
name = ""
}
accountIDField := extensionName(wtAccountIDTpl, clientID)
accountID, ok := ap[accountIDField].(string)
if !ok {
accountID = ""
}
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
pendingInvite, ok := ap[pendingInviteField].(bool)
if !ok {
pendingInvite = false
}
return &UserData{
Email: email,
Name: name,
ID: id,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
}
}
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
req := &azureProfile{
"accountEnabled": true,
"displayName": name,
"mailNickName": strings.Join(strings.Split(name, " "), ""),
"userPrincipalName": email,
"passwordProfile": passwordProfile{
ForceChangePasswordNextSignIn: true,
Password: GeneratePassword(8, 1, 1, 1),
},
wtAccountIDField: accountID,
wtPendingInviteField: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func extensionName(extensionTpl, clientID string) string {
clientID = strings.ReplaceAll(clientID, "-", "")
return fmt.Sprintf(extensionTpl, clientID)
}
// hasExtension checks whether a given extension by name,
// exists in an list of extensions.
func hasExtension(extensions []azureExtension, name string) bool {
for _, ext := range extensions {
if ext.Name == name {
return true
}
}
return false
}

View File

@@ -0,0 +1,329 @@
package idp
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type mockAzureCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestAzureAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestAzureUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &AzureManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestAzureProfile(t *testing.T) {
type azureProfileTest struct {
name string
clientID string
invite bool
inputProfile azureProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "Good Request",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test1",
"displayName": "John Doe",
"userPrincipalName": "test1@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
Email: "test1@test.com",
Name: "John Doe",
ID: "test1",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: true,
inputProfile: azureProfile{
"displayName": "John Doe",
"userPrincipalName": "test2@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
},
expectedUserData: UserData{
Email: "test2@test.com",
Name: "John Doe",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test3",
"userPrincipalName": "test3@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
ID: "test3",
Email: "test3@test.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase4 := azureProfileTest{
name: "Missing Extension Fields",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test4",
"displayName": "John Doe",
"userPrincipalName": "test4@test.com",
},
expectedUserData: UserData{
ID: "test4",
Name: "John Doe",
Email: "test4@test.com",
AppMetadata: AppMetadata{},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData(testCase.clientID)
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}

View File

@@ -19,11 +19,21 @@ type Manager interface {
GetUserByEmail(email string) ([]*UserData, error)
}
// OIDCConfig specifies configuration for OpenID Connect provider
// These configurations are automatically loaded from the OIDC endpoint
type OIDCConfig struct {
Issuer string
TokenEndpoint string
}
// Config an idp configuration struct to be loaded from management server's config file
type Config struct {
ManagerType string
OIDCConfig OIDCConfig `json:"-"`
Auth0ClientCredentials Auth0ClientConfig
AzureClientCredentials AzureClientConfig
KeycloakClientCredentials KeycloakClientConfig
ZitadelClientCredentials ZitadelClientConfig
}
// ManagerCredentials interface that authenticates using the credential of each type of idp
@@ -72,9 +82,13 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
case "none", "":
return nil, nil
case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
return NewAuth0Manager(config.OIDCConfig, config.Auth0ClientCredentials, appMetrics)
case "azure":
return NewAzureManager(config.OIDCConfig, config.AzureClientCredentials, appMetrics)
case "keycloak":
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
return NewKeycloakManager(config.OIDCConfig, config.KeycloakClientCredentials, appMetrics)
case "zitadel":
return NewZitadelManager(config.OIDCConfig, config.ZitadelClientCredentials, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -37,8 +37,8 @@ type KeycloakClientConfig struct {
ClientID string
ClientSecret string
AdminEndpoint string
TokenEndpoint string
GrantType string
TokenEndpoint string `json:"-"`
GrantType string `json:"-"`
}
// KeycloakCredentials keycloak authentication information.
@@ -82,7 +82,8 @@ type keycloakProfile struct {
}
// NewKeycloakManager creates a new instance of the KeycloakManager.
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
func NewKeycloakManager(oidcConfig OIDCConfig, config KeycloakClientConfig,
appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -92,13 +93,19 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
}
helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.GrantType = "client_credentials"
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
return nil, fmt.Errorf("keycloak idp configuration is not complete")
if config.ClientID == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
}
if config.GrantType != "client_credentials" {
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
if config.ClientSecret == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
}
if config.AdminEndpoint == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
}
credentials := &KeycloakCredentials{

View File

@@ -46,19 +46,19 @@ func TestNewKeycloakManager(t *testing.T) {
assertErrFuncMessage: "should return error when field empty",
}
testCase5Config := defaultTestConfig
testCase5Config.GrantType = "authorization_code"
testCase3Config := defaultTestConfig
testCase3Config.ClientSecret = ""
testCase5 := test{
name: "Wrong GrantType",
inputConfig: testCase5Config,
testCase3 := test{
name: "Missing ClientSecret Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong grant type",
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase5} {
for _, testCase := range []test{testCase1, testCase2, testCase3} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
_, err := NewKeycloakManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}

View File

@@ -0,0 +1,609 @@
package idp
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
)
// ZitadelManager zitadel manager client instance.
type ZitadelManager struct {
managementEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// ZitadelClientConfig zitadel manager client configurations.
type ZitadelClientConfig struct {
ClientID string
ClientSecret string
GrantType string `json:"-"`
TokenEndpoint string `json:"-"`
ManagementEndpoint string `json:"-"`
}
// ZitadelCredentials zitadel authentication information.
type ZitadelCredentials struct {
clientConfig ZitadelClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// zitadelEmail specifies details of a user email.
type zitadelEmail struct {
Email string `json:"email"`
IsEmailVerified bool `json:"isEmailVerified"`
}
// zitadelUserInfo specifies user information.
type zitadelUserInfo struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
DisplayName string `json:"displayName"`
}
// zitadelUser specifies profile details for user account.
type zitadelUser struct {
UserName string `json:"userName,omitempty"`
Profile zitadelUserInfo `json:"profile"`
Email zitadelEmail `json:"email"`
}
type zitadelAttributes map[string][]map[string]any
// zitadelMetadata holds additional user data.
type zitadelMetadata struct {
Key string `json:"key"`
Value string `json:"value"`
}
// zitadelProfile represents an zitadel user profile response.
type zitadelProfile struct {
ID string `json:"id"`
State string `json:"state"`
UserName string `json:"userName"`
PreferredLoginName string `json:"preferredLoginName"`
LoginNames []string `json:"loginNames"`
Human *zitadelUser `json:"human"`
Metadata []zitadelMetadata
}
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(oidcConfig OIDCConfig, config ZitadelClientConfig,
appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.ManagementEndpoint = fmt.Sprintf("%s/management/v1", oidcConfig.Issuer)
config.GrantType = "client_credentials"
if config.ClientID == "" {
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, ClientSecret is missing")
}
credentials := &ZitadelCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &ZitadelManager{
managementEndpoint: config.ManagementEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from zitadel.
func (zc *ZitadelCredentials) jwtStillValid() bool {
return !zc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(zc.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret)
data.Set("grant_type", zc.clientConfig.GrantType)
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for zitadel idp manager")
resp, err := zc.httpClient.Do(req)
if err != nil {
if zc.appMetrics != nil {
zc.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds.
func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = zc.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = zc.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the Zitadel Management API.
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
zc.mux.Lock()
defer zc.mux.Unlock()
if zc.appMetrics != nil {
zc.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if zc.jwtStillValid() {
return zc.jwtToken, nil
}
resp, err := zc.requestJWTToken()
if err != nil {
return zc.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := zc.parseRequestJWTResponse(resp.Body)
if err != nil {
return zc.jwtToken, err
}
zc.jwtToken = jwtToken
return zc.jwtToken, nil
}
// CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payload, err := buildZitadelCreateUserRequestPayload(email, name)
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", payload)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var result struct {
UserID string `json:"userId"`
}
err = zm.helper.Unmarshal(body, &result)
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
// Add metadata to new user
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
if err != nil {
return nil, err
}
return zm.GetUserDataByID(result.UserID, appMetadata)
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
searchByEmail := zitadelAttributes{
"queries": {
{
"emailQuery": map[string]any{
"emailAddress": email,
"method": "TEXT_QUERY_METHOD_EQUALS",
},
},
},
}
payload, err := zm.helper.Marshal(searchByEmail)
if err != nil {
return nil, err
}
body, err := zm.post("users/_search", string(payload))
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
users = append(users, profile.userData())
}
return users, nil
}
// GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get("users/"+userID, nil)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile struct{ User zitadelProfile }
err = zm.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
metadata, err := zm.getUserMetadata(userID)
if err != nil {
return nil, err
}
profile.User.Metadata = metadata
return profile.User.userData(), nil
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
accounts, err := zm.GetAllAccounts()
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetAccount()
}
return accounts[accountID], nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := zm.post("users/_search", "")
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Result {
// fetch user metadata
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
if appMetadata.WTPendingInvite == nil {
appMetadata.WTPendingInvite = new(bool)
}
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
metadata := zitadelAttributes{
"metadata": {
{
"key": wtAccountID,
"value": wtAccountIDValue,
},
{
"key": wtPendingInvite,
"value": wtPendingInviteValue,
},
},
}
payload, err := zm.helper.Marshal(metadata)
if err != nil {
return err
}
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
body, err := zm.post(resource, "")
if err != nil {
return nil, err
}
var metadata struct{ Result []zitadelMetadata }
err = zm.helper.Unmarshal(body, &metadata)
if err != nil {
return nil, err
}
return metadata.Result, nil
}
// post perform Post requests.
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 201 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s?%s", zm.managementEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// value returns string represented by the base64 string value.
func (zm zitadelMetadata) value() string {
value, err := base64.StdEncoding.DecodeString(zm.Value)
if err != nil {
return ""
}
return string(value)
}
// userData construct user data from zitadel profile.
func (zp zitadelProfile) userData() *UserData {
var (
email string
name string
wtAccountIDValue string
wtPendingInviteValue bool
)
for _, metadata := range zp.Metadata {
if metadata.Key == wtAccountID {
wtAccountIDValue = metadata.value()
}
if metadata.Key == wtPendingInvite {
value, err := strconv.ParseBool(metadata.value())
if err == nil {
wtPendingInviteValue = value
}
}
}
// Obtain the email for the human account and the login name,
// for the machine account.
if zp.Human != nil {
email = zp.Human.Email.Email
name = zp.Human.Profile.DisplayName
} else {
if len(zp.LoginNames) > 0 {
email = zp.LoginNames[0]
name = zp.LoginNames[0]
}
}
return &UserData{
Email: email,
Name: name,
ID: zp.ID,
AppMetadata: AppMetadata{
WTAccountID: wtAccountIDValue,
WTPendingInvite: &wtPendingInviteValue,
},
}
}
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
var firstName, lastName string
words := strings.Fields(name)
if n := len(words); n > 0 {
firstName = strings.Join(words[:n-1], " ")
lastName = words[n-1]
}
req := &zitadelUser{
UserName: name,
Profile: zitadelUserInfo{
FirstName: strings.TrimSpace(firstName),
LastName: strings.TrimSpace(lastName),
DisplayName: name,
},
Email: zitadelEmail{
Email: email,
IsEmailVerified: false,
},
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}

View File

@@ -0,0 +1,486 @@
package idp
import (
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewZitadelManager(t *testing.T) {
type test struct {
name string
inputConfig ZitadelClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := ZitadelClientConfig{
ClientID: "client_id",
ClientSecret: "client_secret",
GrantType: "client_credentials",
TokenEndpoint: "http://localhost/oauth/v2/token",
ManagementEndpoint: "http://localhost/management/v1",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.ClientID = ""
testCase2 := test{
name: "Missing ClientID Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase3Config := defaultTestConfig
testCase3Config.ClientSecret = ""
testCase3 := test{
name: "Missing ClientSecret Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewZitadelManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}
type mockZitadelCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
inputRespBody string
helper ManagerHelper
expectedFuncExitErrDiff error
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
requestJWTTokenTesttCase1 := requestJWTTokenTest{
name: "Good JWT Response",
inputCode: 200,
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
}
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputRespBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
}
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
resp, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
} else {
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
err = testCase.helper.Unmarshal(body, &jwtToken)
assert.NoError(t, err, "unable to parse the json input")
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
}
})
}
}
func TestZitadelParseRequestJWTResponse(t *testing.T) {
type parseRequestJWTResponseTest struct {
name string
inputRespBody string
helper ManagerHelper
expectedToken string
expectedExpiresIn int
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
exp := 100
token := newTestJWT(t, exp)
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
name: "Parse Good JWT Body",
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
expectedExpiresIn: exp,
assertErrFunc: assert.NoError,
assertErrFuncMessage: "no error was expected",
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputRespBody: "",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
assertErrFunc: assert.Error,
assertErrFuncMessage: "json error was expected",
}
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
helper: testCase.helper,
}
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
})
}
}
func TestZitadelJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestZitadelAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &ZitadelManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct {
name string
invite bool
inputProfile zitadelProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "User Request",
invite: false,
inputProfile: zitadelProfile{
ID: "test1",
State: "USER_STATE_ACTIVE",
UserName: "test1@mail.com",
PreferredLoginName: "test1@mail.com",
LoginNames: []string{
"test1@mail.com",
},
Human: &zitadelUser{
Profile: zitadelUserInfo{
FirstName: "ZITADEL",
LastName: "Admin",
DisplayName: "ZITADEL Admin",
},
Email: zitadelEmail{
Email: "test1@mail.com",
IsEmailVerified: true,
},
},
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "ZmFsc2U=",
},
},
},
expectedUserData: UserData{
ID: "test1",
Name: "ZITADEL Admin",
Email: "test1@mail.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Service User Request",
invite: true,
inputProfile: zitadelProfile{
ID: "test2",
State: "USER_STATE_ACTIVE",
UserName: "machine",
PreferredLoginName: "machine",
LoginNames: []string{
"machine",
},
Human: nil,
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "dHJ1ZQ==",
},
},
},
expectedUserData: UserData{
ID: "test2",
Name: "machine",
Email: "machine",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}

View File

@@ -12,6 +12,10 @@ import (
"fmt"
"math/big"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
@@ -45,7 +49,8 @@ type Options struct {
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// JSONWebKey is a representation of a Jason Web Key
@@ -64,12 +69,13 @@ type JWTValidator struct {
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) {
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
}
var lock sync.Mutex
options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
@@ -89,6 +95,23 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string)
return token, errors.New("invalid issuer")
}
// If keys are rotated, verify the keys prior to token validation
if idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
if !keys.stillValid() {
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(keysLocation)
if err != nil {
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
keys = refreshedKeys
}
}
cert, err := getPemCert(token, keys)
if err != nil {
return nil, err
@@ -154,6 +177,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
@@ -167,6 +195,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
@@ -248,3 +280,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
return int(exponent), nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.Debugf("error parsing max-age: %v", err)
return 0
}
return maxAge
}
}
return 0
}

View File

@@ -15,12 +15,11 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserIDFunc func(userID string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(userID string) (bool, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
@@ -61,11 +60,11 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
@@ -113,14 +112,6 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
)
}
// GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
if am.GetAccountByUserIDFunc != nil {
return am.GetAccountByUserIDFunc(userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
}
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
func (am *MockAccountManager) CreateSetupKey(
accountID string,
@@ -199,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
}
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if am.DeletePATFunc != nil {
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
}
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
}
// GetPAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if am.GetPATFunc != nil {
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
}
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if am.GetAllPATsFunc != nil {
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
}
@@ -394,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
}
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
if am.IsUserAdminFunc != nil {
return am.IsUserAdminFunc(userID)
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(claims)
}
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
@@ -502,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
}
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}

View File

@@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
if peerLoginExpired(peer, account) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
@@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
updateRemotePeers := false
if peerLoginExpired(peer, account) {
err = checkAuth(login.UserID, peer)
@@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
}
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
}
return nil
}
func checkAuth(loginUserID string, peer *Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -50,15 +51,22 @@ type User struct {
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string
PATs map[string]*PersonalAccessToken
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
}
// IsAdmin returns true if user is an admin, false otherwise
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
// IsAdmin returns true if the user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
@@ -73,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil
}
if userData.ID != u.Id {
@@ -92,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil
}
@@ -112,6 +122,7 @@ func (u *User) Copy() *User {
IsServiceUser: u.IsServiceUser,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
}
}
@@ -137,7 +148,7 @@ func NewAdminUser(id string) *User {
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -146,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
@@ -165,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
}
meta := map[string]any{"name": newUser.ServiceUserName}
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{
ID: newUser.Id,
@@ -211,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
}
if user != nil {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
users, err := am.idpManager.GetUserByEmail(invite.Email)
@@ -220,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
}
if len(users) > 0 {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
@@ -248,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
return newUser.toUserInfo(idpUser)
return newUser.ToUserInfo(idpUser)
}
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
return user, nil
}
// DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -267,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
@@ -280,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
}
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUserID)
@@ -293,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
}
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -315,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.NotFound, "targetUser not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
@@ -337,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
return pat, nil
}
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -357,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
@@ -381,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(targetUser.PATs, tokenID)
@@ -393,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
}
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -407,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
return nil, status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
}
@@ -425,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
}
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -439,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return nil, status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
@@ -456,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return pats, nil
}
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -471,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
return nil, err
}
initiatorUser, err := account.FindUser(initiatorUserID)
if err != nil {
return nil, err
}
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "update not found")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.AutoGroups = update.AutoGroups
newUser.Role = update.Role
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
var peerIDs []string
for _, peer := range blockedPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return nil, err
}
}
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
return nil, err
}
}
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
defer func() {
// store activity logs
if oldUser.Role != newUser.Role {
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
}
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
if update.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
}
}
}
}()
if !isNil(am.idpManager) && !newUser.IsServiceUser {
@@ -531,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
if userData == nil {
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
}
return newUser.toUserInfo(userData)
return newUser.ToUserInfo(userData)
}
return newUser.toUserInfo(nil)
return newUser.ToUserInfo(nil)
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
@@ -573,26 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
return account, nil
}
// GetAccountByUserID returns an existing account for a given user id
func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
return am.Store.GetAccountByUser(userID)
}
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
account, err := am.GetAccountByUserID(userID)
if err != nil {
return false, fmt.Errorf("get account: %v", err)
}
user, ok := account.Users[userID]
if !ok {
return false, status.Errorf(status.NotFound, "user not found")
}
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
@@ -629,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.toUserInfo(nil)
info, err := accountUser.ToUserInfo(nil)
if err != nil {
return nil, err
}
@@ -646,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.toUserInfo(queriedUser)
info, err = localUser.ToUserInfo(queriedUser)
if err != nil {
return nil, err
}

View File

@@ -8,8 +8,10 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
@@ -264,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
LastUsed: time.Now(),
},
},
Blocked: false,
}
err := validateStruct(user)
@@ -287,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
field := structVal.Field(i)
fieldName := structType.Field(i).Name
isSet := field.IsValid() && !field.IsZero()
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
if !isSet {
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
@@ -439,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
}
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
@@ -453,38 +456,27 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
claims := jwtclaims.AuthorizationClaims{
UserId: mockUserID,
}
user, err := am.GetUser(claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.True(t, ok)
assert.Equal(t, mockUserID, user.Id)
assert.True(t, user.IsAdmin())
assert.False(t, user.IsBlocked())
}
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
Role: "user",
}
func TestUser_IsAdmin(t *testing.T) {
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
user := NewAdminUser(mockUserID)
assert.True(t, user.IsAdmin())
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.False(t, ok)
user = NewRegularUser(mockUserID)
assert.False(t, user.IsAdmin())
}
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
@@ -541,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
assert.Equal(t, 1, len(users))
assert.Equal(t, mockServiceUserID, users[0].ID)
}
func TestDefaultAccountManager_SaveUser(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
regularUserID := "regularUser"
tt := []struct {
name string
adminInitiator bool
update *User
expectedErr bool
}{
{
name: "Should_Fail_To_Update_Admin_Role",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleUser,
Blocked: false,
},
}, {
name: "Should_Fail_When_Admin_Blocks_Themselves",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_Non_Existing_User",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
expectedErr: true,
adminInitiator: false,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Update_User",
expectedErr: false,
adminInitiator: true,
update: &User{
Id: regularUserID,
Role: UserRoleAdmin,
Blocked: true,
},
},
}
for _, tc := range tt {
// create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
if err != nil {
t.Fatal(err)
}
// create a regular user
account.Users[regularUserID] = NewRegularUser(regularUserID)
err = manager.Store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
initiatorID := userID
if !tc.adminInitiator {
initiatorID = regularUserID
}
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error")
} else {
require.NoError(t, err, "expecting SaveUser not to throw an error")
assert.NotNil(t, updated)
assert.Equal(t, string(tc.update.Role), updated.Role)
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
}
}
}

View File

@@ -36,6 +36,7 @@ download_release_binary() {
echo "Installing $1 from $DOWNLOAD_URL"
cd /tmp && curl -LO "$DOWNLOAD_URL"
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
INSTALL_DIR="/Applications/NetBird UI.app"
@@ -43,8 +44,9 @@ download_release_binary() {
unzip -q -o "$BINARY_NAME"
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
else
sudo mkdir -p "$INSTALL_DIR"
tar -xzvf "$BINARY_NAME"
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR"
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
fi
}

11
sharedsock/filter.go Normal file
View File

@@ -0,0 +1,11 @@
package sharedsock
import "golang.org/x/net/bpf"
const magicCookie uint32 = 0x2112A442
// BPFFilter is a generic filter that provides ipv4 and ipv6 BPF instructions
type BPFFilter interface {
// GetInstructions returns raw BPF instructions for ipv4 and ipv6
GetInstructions(port uint32) (ipv4 []bpf.RawInstruction, ipv6 []bpf.RawInstruction, err error)
}

View File

@@ -0,0 +1,47 @@
package sharedsock
import "golang.org/x/net/bpf"
// IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port.
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
type IncomingSTUNFilter struct {
}
// NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter
func NewIncomingSTUNFilter() BPFFilter {
return &IncomingSTUNFilter{}
}
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
raw4, err = rawInstructions(22, 32, dstPort)
if err != nil {
return nil, nil, err
}
raw6, err = rawInstructions(2, 12, dstPort)
if err != nil {
return nil, nil, err
}
return raw4, raw6, nil
}
func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) {
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
instructions := []bpf.Instruction{
// Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
bpf.LoadAbsolute{Off: dstPortOff, Size: 2},
// Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions.
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3},
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
bpf.LoadAbsolute{Off: cookieOff, Size: 4},
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
// If both the dstPort and the magic cookie match, return a positive value (0xffffffff)
bpf.RetConstant{Val: 0xffffffff},
// If either the dstPort or the magic cookie doesn't match, return 0
bpf.RetConstant{Val: 0},
}
return bpf.Assemble(instructions)
}

View File

@@ -0,0 +1,8 @@
//go:build !linux
package sharedsock
// NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
func NewIncomingSTUNFilter() BPFFilter {
return nil
}

328
sharedsock/sock_linux.go Normal file
View File

@@ -0,0 +1,328 @@
//go:build linux && !android
// Inspired by
// Jason Donenfeld (https://git.zx2c4.com/wireguard-tools/tree/contrib/nat-hole-punching/nat-punch-client.c#n96)
// and @stv0g in https://github.com/stv0g/cunicu/tree/ebpf-poc/ebpf_poc
package sharedsock
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/mdlayher/socket"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
)
// ErrSharedSockStopped indicates that shared socket has been stopped
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
// by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
// It is meant to be used when sharing a port with some other process.
type SharedSocket struct {
ctx context.Context
conn4 *socket.Conn
conn6 *socket.Conn
port int
routerMux sync.RWMutex
router routing.Router
packetDemux chan rcvdPacket
cancel context.CancelFunc
}
type rcvdPacket struct {
n int
addr unix.Sockaddr
buf []byte
err error
}
type receiver func(ctx context.Context, p []byte, flags int) (int, unix.Sockaddr, error)
var writeSerializerOptions = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
// Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
var err error
ctx, cancel := context.WithCancel(context.Background())
rawSock := &SharedSocket{
ctx: ctx,
cancel: cancel,
port: port,
packetDemux: make(chan rcvdPacket),
}
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create router: %rawSock", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil {
return nil, fmt.Errorf("socket.Socket for ipv4 failed with: %rawSock", err)
}
rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
if err != nil {
log.Errorf("socket.Socket for ipv6 failed with: %rawSock", err)
}
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err)
}
err = rawSock.conn4.SetBPF(ipv4Instructions)
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err)
}
if rawSock.conn6 != nil {
err = rawSock.conn6.SetBPF(ipv6Instructions)
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err)
}
}
go rawSock.read(rawSock.conn4.Recvfrom)
if rawSock.conn6 != nil {
go rawSock.read(rawSock.conn6.Recvfrom)
}
go rawSock.updateRouter()
return rawSock, nil
}
// updateRouter updates the listener routing table client
// this is needed to avoid outdated information across different client networks
func (s *SharedSocket) updateRouter() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
router, err := netroute.New()
if err != nil {
log.Errorf("failed to create and update packet router for stunListener: %s", err)
continue
}
s.routerMux.Lock()
s.router = router
s.routerMux.Unlock()
}
}
}
// LocalAddr returns an IPv4 address using the supplied port
func (s *SharedSocket) LocalAddr() net.Addr {
// todo check impact on ipv6 discovery
return &net.UDPAddr{
IP: net.IPv4zero,
Port: s.port,
}
}
// SetDeadline sets both the read and write deadlines associated with the ipv4 and ipv6 Conn sockets
func (s *SharedSocket) SetDeadline(t time.Time) error {
err := s.conn4.SetDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetDeadline error: %s", err)
}
if s.conn6 == nil {
return nil
}
err = s.conn6.SetDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetDeadline error: %s", err)
}
return nil
}
// SetReadDeadline sets the read deadline associated with the ipv4 and ipv6 Conn sockets
func (s *SharedSocket) SetReadDeadline(t time.Time) error {
err := s.conn4.SetReadDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err)
}
if s.conn6 == nil {
return nil
}
err = s.conn6.SetReadDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err)
}
return nil
}
// SetWriteDeadline sets the write deadline associated with the ipv4 and ipv6 Conn sockets
func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
err := s.conn4.SetWriteDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err)
}
if s.conn6 == nil {
return nil
}
err = s.conn6.SetWriteDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err)
}
return nil
}
// Close closes the underlying ipv4 and ipv6 conn sockets
func (s *SharedSocket) Close() error {
s.cancel()
errGrp := errgroup.Group{}
if s.conn4 != nil {
errGrp.Go(s.conn4.Close)
}
if s.conn6 != nil {
errGrp.Go(s.conn6.Close)
}
return errGrp.Wait()
}
// read start a read loop for a specific receiver and sends the packet to the packetDemux channel
func (s *SharedSocket) read(receiver receiver) {
for {
buf := make([]byte, 1500)
n, addr, err := receiver(s.ctx, buf, 0)
select {
case <-s.ctx.Done():
return
case s.packetDemux <- rcvdPacket{n, addr, buf[:n], err}:
}
}
}
// ReadFrom reads packets received in the packetDemux channel
func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
var pkt rcvdPacket
select {
case <-s.ctx.Done():
return -1, nil, ErrSharedSockStopped
case pkt = <-s.packetDemux:
}
if pkt.err != nil {
return -1, nil, pkt.err
}
var ip4layer layers.IPv4
var udp layers.UDP
var payload gopacket.Payload
var parser *gopacket.DecodingLayerParser
var ip net.IP
if sa, isIPv4 := pkt.addr.(*unix.SockaddrInet4); isIPv4 {
ip = sa.Addr[:]
parser = gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &ip4layer, &udp, &payload)
} else if sa, isIPv6 := pkt.addr.(*unix.SockaddrInet6); isIPv6 {
ip = sa.Addr[:]
parser = gopacket.NewDecodingLayerParser(layers.LayerTypeUDP, &udp, &payload)
} else {
return -1, nil, fmt.Errorf("received invalid address family")
}
decodedLayers := make([]gopacket.LayerType, 0, 3)
err = parser.DecodeLayers(pkt.buf[:], &decodedLayers)
if err != nil {
return 0, nil, err
}
remoteAddr := &net.UDPAddr{
IP: ip,
Port: int(udp.SrcPort),
}
copy(b, payload)
return int(udp.Length), remoteAddr, nil
}
// WriteTo builds a UDP packet and writes it using the specific IP version writter
func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
rUDPAddr, ok := rAddr.(*net.UDPAddr)
if !ok {
return -1, fmt.Errorf("invalid address type")
}
buffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(buf)
udp := &layers.UDP{
SrcPort: layers.UDPPort(s.port),
DstPort: layers.UDPPort(rUDPAddr.Port),
}
s.routerMux.RLock()
defer s.routerMux.RUnlock()
_, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil {
return 0, fmt.Errorf("got an error while checking route, err: %s", err)
}
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil {
return -1, fmt.Errorf("failed to set network layer for checksum: %w", err)
}
if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil {
return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err)
}
bufser := buffer.Bytes()
return 0, conn.Sendto(context.TODO(), bufser, 0, rSockAddr)
}
// getWriterObjects returns the specific IP version objects that are used to build a packet and send it using the raw socket
func (s *SharedSocket) getWriterObjects(src, dest net.IP) (sa unix.Sockaddr, conn *socket.Conn, layer gopacket.NetworkLayer) {
if dest.To4() == nil {
sa = &unix.SockaddrInet6{}
copy(sa.(*unix.SockaddrInet6).Addr[:], dest.To16())
conn = s.conn6
layer = &layers.IPv6{
SrcIP: src,
DstIP: dest,
}
} else {
sa = &unix.SockaddrInet4{}
copy(sa.(*unix.SockaddrInet4).Addr[:], dest.To4())
conn = s.conn4
layer = &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: src,
DstIP: dest,
}
}
return sa, conn, layer
}

View File

@@ -0,0 +1,162 @@
package sharedsock
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"sync"
"testing"
"time"
"github.com/pion/stun"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestShouldReadSTUNOnReadFrom(t *testing.T) {
// create raw socket on a port
testingPort := 51821
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
require.NoError(t, err, "unable to set deadline, error: %s", err)
wg := sync.WaitGroup{}
wg.Add(1)
// when reading from the raw socket
buf := make([]byte, 1500)
rcvMSG := &stun.Message{
Raw: buf,
}
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
go func() {
select {
case <-ctx.Done():
return
default:
_, _, err := rawSock.ReadFrom(buf)
if err != nil {
log.Errorf("error while reading packet %s", err)
return
}
err = rcvMSG.Decode()
if err != nil {
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
return
}
wg.Done()
}
}()
// and sending STUN packet to the shared port, the packet has to be handled
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 12345, IP: net.ParseIP("127.0.0.1")})
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
defer udpListener.Close()
stunMSG, err := stun.Build(stun.NewType(stun.MethodBinding, stun.ClassRequest), stun.TransactionID,
stun.Fingerprint,
)
require.NoError(t, err, "unable to build stun msg, error: %s", err)
_, err = udpListener.WriteTo(stunMSG.Raw, net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort))))
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
// the packet has to be handled and be a STUN packet
wg.Wait()
require.EqualValues(t, stunMSG.TransactionID, rcvMSG.TransactionID, "transaction id values did't match")
}
func TestShouldNotReadNonSTUNPackets(t *testing.T) {
testingPort := 39439
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
defer rawSock.Close()
buf := make([]byte, 1500)
err = rawSock.SetReadDeadline(time.Now().Add(time.Second))
require.NoError(t, err, "unable to set deadline, error: %s", err)
errGrp := errgroup.Group{}
errGrp.Go(func() error {
_, _, err := rawSock.ReadFrom(buf)
return err
})
nonStun := []byte("netbird")
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
defer udpListener.Close()
remote := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort)))
_, err = udpListener.WriteTo(nonStun, remote)
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
err = errGrp.Wait()
require.Error(t, err, "should receive an error")
if !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("error should be I/O timeout, got: %s", err)
}
}
func TestWriteTo(t *testing.T) {
udpListener, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
defer udpListener.Close()
testingPort := 39440
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
defer rawSock.Close()
buf := make([]byte, 1500)
err = udpListener.SetReadDeadline(time.Now().Add(3 * time.Second))
require.NoError(t, err, "unable to set deadline, error: %s", err)
errGrp := errgroup.Group{}
var remoteAdr net.Addr
var rcvBytes int
errGrp.Go(func() error {
n, a, err := udpListener.ReadFrom(buf)
remoteAdr = a
rcvBytes = n
return err
})
msg := []byte("netbird")
_, err = rawSock.WriteTo(msg, udpListener.LocalAddr())
require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)
err = errGrp.Wait()
require.NoError(t, err, "received an error while reading the packet, error: %s", err)
require.EqualValues(t, string(msg), string(buf[:rcvBytes]), "received message should match")
udpRcv, ok := remoteAdr.(*net.UDPAddr)
require.True(t, ok, "udp address conversion didn't work")
require.EqualValues(t, testingPort, udpRcv.Port, "received address port didn't match")
}
func TestSharedSocket_Close(t *testing.T) {
rawSock, err := Listen(39440, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
errGrp := errgroup.Group{}
errGrp.Go(func() error {
buf := make([]byte, 1500)
_, _, err := rawSock.ReadFrom(buf)
return err
})
_ = rawSock.Close()
err = errGrp.Wait()
if err != ErrSharedSockStopped {
t.Errorf("invalid error response: %s", err)
}
}

View File

@@ -0,0 +1,14 @@
//go:build !linux || android
package sharedsock
import (
"fmt"
"net"
"runtime"
)
// Listen is not supported on other platforms
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
return nil, fmt.Errorf(fmt.Sprintf("Not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS))
}

View File

@@ -350,7 +350,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
} else if err != nil {
return err
}
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)
if err != nil {