Compare commits

...

12 Commits

Author SHA1 Message Date
Zoltan Papp
7b43d7e8ef Delete proxy package
The proxy package contains duplicated wg
configuration logic
2023-05-12 16:08:35 +02:00
Zoltan Papp
dcc83c8741 Replace the "recv a new msg from peer" debug log
to trace level
2023-05-12 11:07:41 +02:00
Zoltan Papp
d56669ec2e Remove unused dummy proxy 2023-05-12 11:05:59 +02:00
Misha Bragin
e3d2b6a408 Block user through HTTP API (#846)
The new functionality allows blocking a user in the Management service.
Blocked users lose access to the Dashboard, aren't able to modify the network map,
and all of their connected devices disconnect and are set to the "login expired" state.

Technically all above was achieved with the updated PUT /api/users endpoint,
that was extended with the is_blocked field.
2023-05-11 18:09:36 +02:00
Zoltan Papp
9f758b2015 Fix preshared key command line arg handling (#850) 2023-05-11 18:09:06 +02:00
Bethuel
2c50d7af1e Automatically load IdP OIDC configuration (#847) 2023-05-11 15:14:00 +02:00
pascal-fischer
e4c28f64fa Fix user cache lookup filtering for service users (#849) 2023-05-10 19:27:17 +02:00
Maycon Santos
6f2c4078ef Fix macOS installer script (#844)
Create /usr/local/bin/ folder before installation
2023-05-09 16:22:02 +02:00
Bethuel
f4ec1699ca Add Zitadel IdP (#833)
Added intergration with Zitadel management API.

Use the steps in zitadel.md for configuration.
2023-05-05 19:27:28 +02:00
Bethuel
fea53b2f0f Fix incomplete verification URI issue in device auth flow (#838)
Adds functionality to support Identity Provider (IdP) managers 
that do not support a complete verification URI in the 
device authentication flow. 
In cases where the verification_uri_complete field is empty,
the user will be prompted with their user_code, 
and the verification_uri  field will be used as a fallback
2023-05-05 12:43:04 +02:00
Zoltan Papp
60e6d0890a Fix sharedsock build on android (#837) 2023-05-05 10:55:23 +02:00
Misha Bragin
cb12e2da21 Correct sharedsock BPF fields (#835) 2023-05-04 12:28:32 +02:00
39 changed files with 1859 additions and 522 deletions

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
@@ -45,12 +46,16 @@ var loginCmd = &cobra.Command{
return err return err
} }
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
ConfigPath: configPath, ConfigPath: configPath,
PreSharedKey: &preSharedKey, }
}) if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -106,7 +111,7 @@ var loginCmd = &cobra.Command{
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
@@ -185,7 +190,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request device code failed: %v", err) return nil, fmt.Errorf("getting a request device code failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
@@ -199,11 +204,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
err := open.Run(verificationURIComplete) err := open.Run(verificationURIComplete)
cmd.Printf("Please do the SSO login in your browser. \n" + cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
" " + verificationURIComplete + " \n\n") " " + verificationURIComplete + " " + codeMsg + " \n\n")
if err != nil { if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n") cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
} }

View File

@@ -78,14 +78,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
ConfigPath: configPath, ConfigPath: configPath,
PreSharedKey: &preSharedKey,
NATExternalIPs: natExternalIPs, NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
}) }
if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -172,7 +176,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -210,7 +209,7 @@ func (e *Engine) Start() error {
e.udpMux = udpMux e.udpMux = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String()) log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else { } else {
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewSTUNFilter()) rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter())
if err != nil { if err != nil {
return err return err
} }
@@ -247,7 +246,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok { if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") { if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue continue
} }
@@ -758,9 +757,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
// we might have received new STUN and TURN servers meanwhile, so update them // we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
conf := conn.GetConf() conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
conf.StunTurn = append(e.STUNs, e.TURNs...)
conn.UpdateConf(conf)
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
err := conn.Open() err := conn.Open()
@@ -789,9 +786,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
stunTurn = append(stunTurn, e.STUNs...) stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort), WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface, WgInterface: e.wgInterface,
AllowedIps: allowedIPs, AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey, PreSharedKey: e.config.PreSharedKey,
@@ -808,7 +805,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
Timeout: timeout, Timeout: timeout,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
ProxyConfig: proxyConfig, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(), UserspaceBind: e.wgInterface.IsUserspaceBind(),

View File

@@ -148,6 +148,11 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err) return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
} }
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err return deviceCode, err
} }

View File

@@ -12,8 +12,8 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
@@ -28,8 +28,18 @@ const (
iceKeepAliveDefault = 4 * time.Second iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second
defaultWgKeepAlive = 25 * time.Second
) )
type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
// ConnConfig is a peer Connection configuration // ConnConfig is a peer Connection configuration
type ConnConfig struct { type ConnConfig struct {
@@ -48,7 +58,7 @@ type ConnConfig struct {
Timeout time.Duration Timeout time.Duration
ProxyConfig proxy.Config WgConfig WgConfig
UDPMux ice.UDPMux UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux UDPMuxSrflx ice.UniversalUDPMux
@@ -103,7 +113,7 @@ type Conn struct {
statusRecorder *Status statusRecorder *Status
proxy proxy.Proxy proxy *WireGuardProxy
remoteModeCh chan ModeMessage remoteModeCh chan ModeMessage
meta meta meta meta
@@ -127,9 +137,14 @@ func (conn *Conn) GetConf() ConnConfig {
return conn.config return conn.config
} }
// UpdateConf updates the connection config // WgConfig returns the WireGuard config
func (conn *Conn) UpdateConf(conf ConnConfig) { func (conn *Conn) WgConfig() WgConfig {
conn.config = conf return conn.config.WgConfig
}
// UpdateStunTurn update the turn and stun addresses
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
conn.config.StunTurn = turnStun
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@@ -240,12 +255,12 @@ func readICEAgentConfigProperties() (time.Duration, time.Duration) {
func (conn *Conn) Open() error { func (conn *Conn) Open() error {
log.Debugf("trying to connect to peer %s", conn.config.Key) log.Debugf("trying to connect to peer %s", conn.config.Key)
peerState := State{PubKey: conn.config.Key} peerState := State{
PubKey: conn.config.Key,
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0] IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
peerState.ConnStatus = conn.status ConnStatus: conn.status,
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -300,10 +315,11 @@ func (conn *Conn) Open() error {
defer conn.notifyDisconnected() defer conn.notifyDisconnected()
conn.mu.Unlock() conn.mu.Unlock()
peerState = State{PubKey: conn.config.Key} peerState = State{
PubKey: conn.config.Key,
peerState.ConnStatus = conn.status ConnStatus: conn.status,
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
}
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -334,19 +350,12 @@ func (conn *Conn) Open() error {
remoteWgPort = remoteOfferAnswer.WgListenPort remoteWgPort = remoteOfferAnswer.WgListenPort
} }
// the ice connection has been established successfully so we are ready to start the proxy // the ice connection has been established successfully so we are ready to start the proxy
err = conn.startProxy(remoteConn, remoteWgPort) remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
if err != nil { if err != nil {
return err return err
} }
if conn.proxy.Type() == proxy.TypeDirectNoProxy { log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine) // wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select { select {
@@ -363,54 +372,58 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error { func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var pair *ice.CandidatePair
pair, err := conn.agent.GetSelectedCandidatePair() pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
return err return nil, err
} }
peerState := State{PubKey: conn.config.Key} var endpoint net.Addr
p := conn.getProxy(pair, remoteWgPort) if isRelayCandidate(pair.Local) {
conn.proxy = p conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
err = p.Start(remoteConn) endpoint, err = conn.proxy.Start()
if err != nil {
conn.proxy = nil
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
if err != nil { if err != nil {
return err if conn.proxy != nil {
_ = conn.proxy.Close()
}
return nil, err
} }
conn.status = StatusConnected conn.status = StatusConnected
peerState.ConnStatus = conn.status peerState := State{
peerState.ConnStatusUpdate = time.Now() PubKey: conn.config.Key,
peerState.LocalIceCandidateType = pair.Local.Type().String() ConnStatus: conn.status,
peerState.RemoteIceCandidateType = pair.Remote.Type().String() ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
Direct: conn.proxy == nil,
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err) log.Warnf("unable to save peer's state, got error: %v", err)
} }
return nil return endpoint, nil
}
// todo rename this method and the proxy package to something more appropriate
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
if isRelayCandidate(pair.Local) {
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
}
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
return proxy.NewNoProxy(conn.config.ProxyConfig)
} }
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -439,20 +452,22 @@ func (conn *Conn) cleanup() error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var err1, err2, err3 error
if conn.agent != nil { if conn.agent != nil {
err := conn.agent.Close() err1 = conn.agent.Close()
if err != nil { if err1 == nil {
return err conn.agent = nil
} }
conn.agent = nil
} }
// todo: is it problem if we try to remove a peer what is never existed?
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.proxy != nil { if conn.proxy != nil {
err := conn.proxy.Close() err3 = conn.proxy.Close()
if err != nil { if err3 != nil {
return err conn.proxy = nil
} }
conn.proxy = nil
} }
if conn.notifyDisconnected != nil { if conn.notifyDisconnected != nil {
@@ -462,10 +477,11 @@ func (conn *Conn) cleanup() error {
conn.status = StatusDisconnected conn.status = StatusDisconnected
peerState := State{PubKey: conn.config.Key} peerState := State{
peerState.ConnStatus = conn.status PubKey: conn.config.Key,
peerState.ConnStatusUpdate = time.Now() ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
@@ -474,8 +490,13 @@ func (conn *Conn) cleanup() error {
} }
log.Debugf("cleaned up connection to peer %s", conn.config.Key) log.Debugf("cleaned up connection to peer %s", conn.config.Key)
if err1 != nil {
return nil return err1
}
if err2 != nil {
return err2
}
return err3
} }
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer // SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer

View File

@@ -1,7 +1,8 @@
package proxy package peer
import ( import (
"context" "context"
"fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net" "net"
) )
@@ -11,67 +12,45 @@ type WireGuardProxy struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
config Config wgListenPort int
remoteKey string
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
} }
func NewWireGuardProxy(config Config) *WireGuardProxy { func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
p := &WireGuardProxy{config: config} p := &WireGuardProxy{
wgListenPort: wgListenPort,
remoteKey: remoteKey,
remoteConn: remoteConn,
}
p.ctx, p.cancel = context.WithCancel(context.Background()) p.ctx, p.cancel = context.WithCancel(context.Background())
return p return p
} }
func (p *WireGuardProxy) updateEndpoint() error { func (p *WireGuardProxy) Start() (net.Addr, error) {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
if err != nil {
return err
}
// add local proxy connection as a Wireguard peer
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
udpAddr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return err return nil, err
}
err = p.updateEndpoint()
if err != nil {
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
return err
} }
p.localConn = lConn
go p.proxyToRemote() go p.proxyToRemote()
go p.proxyToLocal() go p.proxyToLocal()
return nil return lConn.LocalAddr(), nil
} }
func (p *WireGuardProxy) Close() error { func (p *WireGuardProxy) Close() error {
p.cancel() p.cancel()
if c := p.localConn; c != nil { if p.localConn != nil {
err := p.localConn.Close() err := p.localConn.Close()
if err != nil { if err != nil {
return err return err
} }
} }
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil return nil
} }
@@ -83,7 +62,7 @@ func (p *WireGuardProxy) proxyToRemote() {
for { for {
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey) log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
return return
default: default:
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
@@ -107,7 +86,7 @@ func (p *WireGuardProxy) proxyToLocal() {
for { for {
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey) log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
return return
default: default:
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
@@ -122,7 +101,3 @@ func (p *WireGuardProxy) proxyToLocal() {
} }
} }
} }
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
}

View File

@@ -1,72 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
"time"
)
// DummyProxy just sends pings to the RemoteKey peer and reads responses
type DummyProxy struct {
conn net.Conn
remote string
ctx context.Context
cancel context.CancelFunc
}
func NewDummyProxy(remote string) *DummyProxy {
p := &DummyProxy{remote: remote}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *DummyProxy) Close() error {
p.cancel()
return nil
}
func (p *DummyProxy) Start(remoteConn net.Conn) error {
p.conn = remoteConn
go func() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Read(buf)
if err != nil {
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
return
}
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
}
}
}()
go func() {
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Write([]byte("hello"))
//log.Debugf("sent ping to %s", p.remote)
if err != nil {
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
return
}
time.Sleep(5 * time.Second)
}
}
}()
return nil
}
func (p *DummyProxy) Type() Type {
return TypeDummy
}

View File

@@ -1,42 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
type NoProxy struct {
config Config
}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
}
func (p *NoProxy) Type() Type {
return TypeNoProxy
}

View File

@@ -1,35 +0,0 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
"net"
"time"
)
const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
)
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface {
io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn
Start(remoteConn net.Conn) error
Type() Type
}

View File

@@ -77,12 +77,17 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional // Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint net.Addr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint) rAddr, err := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) if err != nil {
return err
}
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, rAddr, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface

View File

@@ -396,6 +396,12 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
} }
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("configuring IdpManagerConfig.OIDCConfig.Issuer with a new value %s,", oidcConfig.Issuer)
config.IdpManagerConfig.OIDCConfig.Issuer = strings.TrimRight(oidcConfig.Issuer, "/")
log.Infof("configuring IdpManagerConfig.OIDCConfig.TokenEndpoint with a new value %s,", oidcConfig.TokenEndpoint)
config.IdpManagerConfig.OIDCConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, config.HttpConfig.AuthIssuer) oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
config.HttpConfig.AuthIssuer = oidcConfig.Issuer config.HttpConfig.AuthIssuer = oidcConfig.Issuer
@@ -441,7 +447,7 @@ type OIDCConfigResponse struct {
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint) res, err := http.Get(oidcEndpoint)
if err != nil { if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err) return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
} }
defer func() { defer func() {

View File

@@ -49,16 +49,16 @@ type AccountManager interface {
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error) autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, executingUserID string, targetUserID string) error DeleteUser(accountID, initiatorUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error) GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error)
@@ -69,10 +69,10 @@ type AccountManager interface {
GetNetworkMap(peerID string) (*NetworkMap, error) GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error) GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(peerID string, sshKey string) error UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
@@ -179,6 +179,7 @@ type UserInfo struct {
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
Status string `json:"-"` Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"` IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
} }
// getRoutesToSync returns the enabled routes for the peer ID and the routes // getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -902,7 +903,9 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users { for _, user := range account.Users {
users[user.Id] = struct{}{} if !user.IsServiceUser {
users[user.Id] = struct{}{}
}
} }
log.Debugf("looking up user %s of account %s in cache", userID, account.Id) log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id) userData, err := am.lookupCache(users, account.Id)

View File

@@ -91,6 +91,10 @@ const (
ServiceUserCreated ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user // ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted ServiceUserDeleted
// UserBlocked indicates that a user blocked another user
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
) )
const ( const (
@@ -184,6 +188,10 @@ const (
ServiceUserCreatedMessage string = "Service user created" ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity // ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted" ServiceUserDeletedMessage string = "Service user deleted"
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked"
) )
// Activity that triggered an Event // Activity that triggered an Event
@@ -282,6 +290,10 @@ func (a Activity) Message() string {
return ServiceUserCreatedMessage return ServiceUserCreatedMessage
case ServiceUserDeleted: case ServiceUserDeleted:
return ServiceUserDeletedMessage return ServiceUserDeletedMessage
case UserBlocked:
return UserBlockedMessage
case UserUnblocked:
return UserUnblockedMessage
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
@@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
return "user.join" return "user.join"
case UserInvited: case UserInvited:
return "user.invite" return "user.invite"
case UserBlocked:
return "user.block"
case UserUnblocked:
return "user.unblock"
case AccountCreated: case AccountCreated:
return "account.create" return "account.create"
case RuleAdded: case RuleAdded:

View File

@@ -65,7 +65,7 @@ components:
status: status:
description: User's status description: User's status
type: string type: string
enum: [ "active","invited","disabled" ] enum: [ "active","invited","blocked" ]
auto_groups: auto_groups:
description: Groups to auto-assign to peers registered by this user description: Groups to auto-assign to peers registered by this user
type: array type: array
@@ -79,6 +79,9 @@ components:
description: Is true if this user is a service user description: Is true if this user is a service user
type: boolean type: boolean
readOnly: true readOnly: true
is_blocked:
description: Is true if this user is blocked. Blocked users can't use the system
type: boolean
required: required:
- id - id
- email - email
@@ -86,6 +89,7 @@ components:
- role - role
- auto_groups - auto_groups
- status - status
- is_blocked
UserRequest: UserRequest:
type: object type: object
properties: properties:
@@ -97,9 +101,13 @@ components:
type: array type: array
items: items:
type: string type: string
is_blocked:
description: If set to true then user is blocked and can't use the system
type: boolean
required: required:
- role - role
- auto_groups - auto_groups
- is_blocked
UserCreateRequest: UserCreateRequest:
type: object type: object
properties: properties:
@@ -645,7 +653,7 @@ components:
description: The string code of the activity that occurred during the event description: The string code of the activity that occurred during the event
type: string type: string
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete", enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
"user.role.update", "user.role.update", "user.block", "user.unblock",
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse", "setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
"setupkey.group.delete", "setupkey.group.add", "setupkey.group.delete", "setupkey.group.add",
"rule.add", "rule.delete", "rule.update", "rule.add", "rule.delete", "rule.update",

View File

@@ -46,6 +46,7 @@ const (
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add" EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke" EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update" EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
EventActivityCodeUserBlock EventActivityCode = "user.block"
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
EventActivityCodeUserInvite EventActivityCode = "user.invite" EventActivityCodeUserInvite EventActivityCode = "user.invite"
@@ -53,6 +54,7 @@ const (
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add" EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update" EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
) )
// Defines values for NameserverNsType. // Defines values for NameserverNsType.
@@ -68,9 +70,9 @@ const (
// Defines values for UserStatus. // Defines values for UserStatus.
const ( const (
UserStatusActive UserStatus = "active" UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled" UserStatusBlocked UserStatus = "blocked"
UserStatusInvited UserStatus = "invited" UserStatusInvited UserStatus = "invited"
) )
// Account defines model for Account. // Account defines model for Account.
@@ -552,6 +554,9 @@ type User struct {
// Id User ID // Id User ID
Id string `json:"id"` Id string `json:"id"`
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
IsBlocked bool `json:"is_blocked"`
// IsCurrent Is true if authenticated user is the same as this user // IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"` IsCurrent *bool `json:"is_current,omitempty"`
@@ -594,6 +599,9 @@ type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user // AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// IsBlocked If set to true then user is blocked and can't use the system
IsBlocked bool `json:"is_blocked"`
// Role User's NetBird account role // Role User's NetBird account role
Role string `json:"role"` Role string `json:"role"`
} }

View File

@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
acMiddleware := middleware.NewAccessControl( acMiddleware := middleware.NewAccessControl(
authCfg.Audience, authCfg.Audience,
authCfg.UserIDClaim, authCfg.UserIDClaim,
accountManager.IsUserAdmin) accountManager.GetUser)
rootRouter := mux.NewRouter() rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware() metricsMiddleware := appMetrics.HTTPMiddleware()

View File

@@ -6,28 +6,30 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct { type AccessControl struct {
isUserAdmin IsUserAdminFunc
claimsExtract jwtclaims.ClaimsExtractor claimsExtract jwtclaims.ClaimsExtractor
getUser GetUser
} }
// NewAccessControl instance constructor // NewAccessControl instance constructor
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl { func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
return &AccessControl{ return &AccessControl{
isUserAdmin: isUserAdmin,
claimsExtract: *jwtclaims.NewClaimsExtractor( claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience), jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim), jwtclaims.WithUserIDClaim(userIDClaim),
), ),
getUser: getUser,
} }
} }
@@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := a.claimsExtract.FromRequestContext(r) claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(claims) user, err := a.getUser(claims)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
} }
if !ok {
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
if !user.IsAdmin() {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil { if err != nil {
log.Debugf("Regex failed") log.Debugf("regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w) util.WriteError(status.Errorf(status.Internal, ""), w)
return return
} }
if ok { if ok {
log.Debugf("Valid Path") log.Debugf("valid Path")
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
return return
} }

View File

@@ -63,7 +63,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler { func initPATTestData() *PATHandler {
return &PATHandler{ return &PATHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil return testAccount, testAccount.Users[existingUserID], nil
}, },
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error { DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID { if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID) return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
} }
return nil return nil
}, },
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
} }
return testAccount.Users[existingUserID].PATs[existingTokenID], nil return testAccount.Users[existingUserID].PATs[existingTokenID], nil
}, },
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }

View File

@@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.AutoGroups == nil {
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
return
}
userRole := server.StrRoleToUserRole(req.Role) userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown { if userRole == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
@@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
Id: userID, Id: userID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
}) })
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
case "invited": case "invited":
userStatus = api.UserStatusInvited userStatus = api.UserStatusInvited
default: default:
userStatus = api.UserStatusDisabled userStatus = api.UserStatusBlocked
}
if user.IsBlocked {
userStatus = api.UserStatusBlocked
} }
isCurrent := user.ID == currenUserID isCurrent := user.ID == currenUserID
@@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
Status: userStatus, Status: userStatus,
IsCurrent: &isCurrent, IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser, IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
} }
} }

View File

@@ -3,6 +3,7 @@ package http
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
Id: existingUserID, Id: existingUserID,
Role: "admin", Role: "admin",
IsServiceUser: false, IsServiceUser: false,
AutoGroups: []string{"group_1"},
}, },
regularUserID: { regularUserID: {
Id: regularUserID, Id: regularUserID,
Role: "user", Role: "user",
IsServiceUser: false, IsServiceUser: false,
AutoGroups: []string{"group_1"},
}, },
serviceUserID: { serviceUserID: {
Id: serviceUserID, Id: serviceUserID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
AutoGroups: []string{"group_1"},
}, },
}, },
} }
@@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
} }
return key, nil return key, nil
}, },
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error { DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID { if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID) return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
} }
@@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
} }
return nil return nil
}, },
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil)
if err != nil {
return nil, err
}
return info, nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
} }
} }
func TestUpdateUser(t *testing.T) {
tt := []struct {
name string
expectedStatusCode int
requestType string
requestPath string
requestBody io.Reader
expectedUserID string
expectedRole string
expectedStatus string
expectedBlocked bool
expectedIsServiceUser bool
expectedGroups []string
}{
{
name: "Update_Block_User",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: true,
expectedRole: "user",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
},
{
name: "Update_Change_Role_To_Admin",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Update_Groups",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Should_Fail_Because_AutoGroups_Is_Absent",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusBadRequest,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatusCode {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if tc.expectedStatusCode == 200 {
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
respBody := &api.User{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("response content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedUserID, respBody.Id)
assert.Equal(t, tc.expectedRole, respBody.Role)
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
for _, expectedGroup := range tc.expectedGroups {
exists := false
for _, actualGroup := range respBody.AutoGroups {
if expectedGroup == actualGroup {
exists = true
}
}
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
}
}
})
}
}
func TestCreateUser(t *testing.T) { func TestCreateUser(t *testing.T) {
name := "name" name := "name"
email := "email" email := "email"

View File

@@ -32,10 +32,10 @@ type Auth0Manager struct {
// Auth0ClientConfig auth0 manager client configurations // Auth0ClientConfig auth0 manager client configurations
type Auth0ClientConfig struct { type Auth0ClientConfig struct {
Audience string Audience string
AuthIssuer string AuthIssuer string `json:"-"`
ClientID string ClientID string
ClientSecret string ClientSecret string
GrantType string GrantType string `json:"-"`
} }
// auth0JWTRequest payload struct to request a JWT Token // auth0JWTRequest payload struct to request a JWT Token
@@ -110,7 +110,8 @@ type auth0Profile struct {
} }
// NewAuth0Manager creates a new instance of the Auth0Manager // NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) { func NewAuth0Manager(oidcConfig OIDCConfig, config Auth0ClientConfig,
appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -121,17 +122,19 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
} }
helper := JsonParser{} helper := JsonParser{}
config.AuthIssuer = oidcConfig.TokenEndpoint
config.GrantType = "client_credentials"
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.Audience == "" || config.AuthIssuer == "" { if config.ClientID == "" {
return nil, fmt.Errorf("auth0 idp configuration is not complete") return nil, fmt.Errorf("auth0 IdP configuration is incomplete, clientID is missing")
} }
if config.GrantType != "client_credentials" { if config.ClientSecret == "" {
return nil, fmt.Errorf("auth0 idp configuration failed. Grant Type should be client_credentials") return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
} }
if !strings.HasPrefix(strings.ToLower(config.AuthIssuer), "https://") { if config.Audience == "" {
return nil, fmt.Errorf("auth0 idp configuration failed. AuthIssuer should contain https://") return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
} }
credentials := &Auth0Credentials{ credentials := &Auth0Credentials{

View File

@@ -459,26 +459,9 @@ func TestNewAuth0Manager(t *testing.T) {
testCase3Config := defaultTestConfig testCase3Config := defaultTestConfig
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com" testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
testCase3 := test{ for _, testCase := range []test{testCase1, testCase2} {
name: "Wrong Auth Issuer Format",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong auth issuer format",
}
testCase4Config := defaultTestConfig
testCase4Config.GrantType = "spa"
testCase4 := test{
name: "Wrong Grant Type",
inputConfig: testCase4Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong grant type",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{}) _, err := NewAuth0Manager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
}) })
} }

View File

@@ -37,12 +37,13 @@ type AzureManager struct {
// AzureClientConfig azure manager client configurations. // AzureClientConfig azure manager client configurations.
type AzureClientConfig struct { type AzureClientConfig struct {
ClientID string ClientID string
ClientSecret string ClientSecret string
GraphAPIEndpoint string ObjectID string
ObjectID string
TokenEndpoint string GraphAPIEndpoint string `json:"-"`
GrantType string TokenEndpoint string `json:"-"`
GrantType string `json:"-"`
} }
// AzureCredentials azure authentication information. // AzureCredentials azure authentication information.
@@ -74,7 +75,8 @@ type azureExtension struct {
} }
// NewAzureManager creates a new instance of the AzureManager. // NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) { func NewAzureManager(oidcConfig OIDCConfig, config AzureClientConfig,
appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -84,13 +86,20 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
} }
helper := JsonParser{} helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.GraphAPIEndpoint = "https://graph.microsoft.com"
config.GrantType = "client_credentials"
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.GraphAPIEndpoint == "" || config.TokenEndpoint == "" { if config.ClientID == "" {
return nil, fmt.Errorf("azure idp configuration is not complete") return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
} }
if config.GrantType != "client_credentials" { if config.ClientSecret == "" {
return nil, fmt.Errorf("azure idp configuration failed. Grant Type should be client_credentials") return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
}
if config.ObjectID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
} }
credentials := &AzureCredentials{ credentials := &AzureCredentials{

View File

@@ -19,12 +19,21 @@ type Manager interface {
GetUserByEmail(email string) ([]*UserData, error) GetUserByEmail(email string) ([]*UserData, error)
} }
// OIDCConfig specifies configuration for OpenID Connect provider
// These configurations are automatically loaded from the OIDC endpoint
type OIDCConfig struct {
Issuer string
TokenEndpoint string
}
// Config an idp configuration struct to be loaded from management server's config file // Config an idp configuration struct to be loaded from management server's config file
type Config struct { type Config struct {
ManagerType string ManagerType string
OIDCConfig OIDCConfig `json:"-"`
Auth0ClientCredentials Auth0ClientConfig Auth0ClientCredentials Auth0ClientConfig
KeycloakClientCredentials KeycloakClientConfig
AzureClientCredentials AzureClientConfig AzureClientCredentials AzureClientConfig
KeycloakClientCredentials KeycloakClientConfig
ZitadelClientCredentials ZitadelClientConfig
} }
// ManagerCredentials interface that authenticates using the credential of each type of idp // ManagerCredentials interface that authenticates using the credential of each type of idp
@@ -73,11 +82,13 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
case "none", "": case "none", "":
return nil, nil return nil, nil
case "auth0": case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics) return NewAuth0Manager(config.OIDCConfig, config.Auth0ClientCredentials, appMetrics)
case "azure": case "azure":
return NewAzureManager(config.AzureClientCredentials, appMetrics) return NewAzureManager(config.OIDCConfig, config.AzureClientCredentials, appMetrics)
case "keycloak": case "keycloak":
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics) return NewKeycloakManager(config.OIDCConfig, config.KeycloakClientCredentials, appMetrics)
case "zitadel":
return NewZitadelManager(config.OIDCConfig, config.ZitadelClientCredentials, appMetrics)
default: default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
} }

View File

@@ -37,8 +37,8 @@ type KeycloakClientConfig struct {
ClientID string ClientID string
ClientSecret string ClientSecret string
AdminEndpoint string AdminEndpoint string
TokenEndpoint string TokenEndpoint string `json:"-"`
GrantType string GrantType string `json:"-"`
} }
// KeycloakCredentials keycloak authentication information. // KeycloakCredentials keycloak authentication information.
@@ -82,7 +82,8 @@ type keycloakProfile struct {
} }
// NewKeycloakManager creates a new instance of the KeycloakManager. // NewKeycloakManager creates a new instance of the KeycloakManager.
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) { func NewKeycloakManager(oidcConfig OIDCConfig, config KeycloakClientConfig,
appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -92,13 +93,19 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
} }
helper := JsonParser{} helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.GrantType = "client_credentials"
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" { if config.ClientID == "" {
return nil, fmt.Errorf("keycloak idp configuration is not complete") return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
} }
if config.GrantType != "client_credentials" { if config.ClientSecret == "" {
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials") return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
}
if config.AdminEndpoint == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
} }
credentials := &KeycloakCredentials{ credentials := &KeycloakCredentials{

View File

@@ -46,19 +46,19 @@ func TestNewKeycloakManager(t *testing.T) {
assertErrFuncMessage: "should return error when field empty", assertErrFuncMessage: "should return error when field empty",
} }
testCase5Config := defaultTestConfig testCase3Config := defaultTestConfig
testCase5Config.GrantType = "authorization_code" testCase3Config.ClientSecret = ""
testCase5 := test{ testCase3 := test{
name: "Wrong GrantType", name: "Missing ClientSecret Configuration",
inputConfig: testCase5Config, inputConfig: testCase3Config,
assertErrFunc: require.Error, assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when wrong grant type", assertErrFuncMessage: "should return error when field empty",
} }
for _, testCase := range []test{testCase1, testCase2, testCase5} { for _, testCase := range []test{testCase1, testCase2, testCase3} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) _, err := NewKeycloakManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
}) })
} }

View File

@@ -0,0 +1,609 @@
package idp
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
)
// ZitadelManager zitadel manager client instance.
type ZitadelManager struct {
managementEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// ZitadelClientConfig zitadel manager client configurations.
type ZitadelClientConfig struct {
ClientID string
ClientSecret string
GrantType string `json:"-"`
TokenEndpoint string `json:"-"`
ManagementEndpoint string `json:"-"`
}
// ZitadelCredentials zitadel authentication information.
type ZitadelCredentials struct {
clientConfig ZitadelClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// zitadelEmail specifies details of a user email.
type zitadelEmail struct {
Email string `json:"email"`
IsEmailVerified bool `json:"isEmailVerified"`
}
// zitadelUserInfo specifies user information.
type zitadelUserInfo struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
DisplayName string `json:"displayName"`
}
// zitadelUser specifies profile details for user account.
type zitadelUser struct {
UserName string `json:"userName,omitempty"`
Profile zitadelUserInfo `json:"profile"`
Email zitadelEmail `json:"email"`
}
type zitadelAttributes map[string][]map[string]any
// zitadelMetadata holds additional user data.
type zitadelMetadata struct {
Key string `json:"key"`
Value string `json:"value"`
}
// zitadelProfile represents an zitadel user profile response.
type zitadelProfile struct {
ID string `json:"id"`
State string `json:"state"`
UserName string `json:"userName"`
PreferredLoginName string `json:"preferredLoginName"`
LoginNames []string `json:"loginNames"`
Human *zitadelUser `json:"human"`
Metadata []zitadelMetadata
}
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(oidcConfig OIDCConfig, config ZitadelClientConfig,
appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
config.TokenEndpoint = oidcConfig.TokenEndpoint
config.ManagementEndpoint = fmt.Sprintf("%s/management/v1", oidcConfig.Issuer)
config.GrantType = "client_credentials"
if config.ClientID == "" {
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, ClientSecret is missing")
}
credentials := &ZitadelCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &ZitadelManager{
managementEndpoint: config.ManagementEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from zitadel.
func (zc *ZitadelCredentials) jwtStillValid() bool {
return !zc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(zc.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret)
data.Set("grant_type", zc.clientConfig.GrantType)
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for zitadel idp manager")
resp, err := zc.httpClient.Do(req)
if err != nil {
if zc.appMetrics != nil {
zc.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds.
func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = zc.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = zc.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the Zitadel Management API.
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
zc.mux.Lock()
defer zc.mux.Unlock()
if zc.appMetrics != nil {
zc.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if zc.jwtStillValid() {
return zc.jwtToken, nil
}
resp, err := zc.requestJWTToken()
if err != nil {
return zc.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := zc.parseRequestJWTResponse(resp.Body)
if err != nil {
return zc.jwtToken, err
}
zc.jwtToken = jwtToken
return zc.jwtToken, nil
}
// CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payload, err := buildZitadelCreateUserRequestPayload(email, name)
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", payload)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var result struct {
UserID string `json:"userId"`
}
err = zm.helper.Unmarshal(body, &result)
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
// Add metadata to new user
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
if err != nil {
return nil, err
}
return zm.GetUserDataByID(result.UserID, appMetadata)
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
searchByEmail := zitadelAttributes{
"queries": {
{
"emailQuery": map[string]any{
"emailAddress": email,
"method": "TEXT_QUERY_METHOD_EQUALS",
},
},
},
}
payload, err := zm.helper.Marshal(searchByEmail)
if err != nil {
return nil, err
}
body, err := zm.post("users/_search", string(payload))
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
users = append(users, profile.userData())
}
return users, nil
}
// GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get("users/"+userID, nil)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile struct{ User zitadelProfile }
err = zm.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
metadata, err := zm.getUserMetadata(userID)
if err != nil {
return nil, err
}
profile.User.Metadata = metadata
return profile.User.userData(), nil
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
accounts, err := zm.GetAllAccounts()
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetAccount()
}
return accounts[accountID], nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := zm.post("users/_search", "")
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Result {
// fetch user metadata
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
if appMetadata.WTPendingInvite == nil {
appMetadata.WTPendingInvite = new(bool)
}
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
metadata := zitadelAttributes{
"metadata": {
{
"key": wtAccountID,
"value": wtAccountIDValue,
},
{
"key": wtPendingInvite,
"value": wtPendingInviteValue,
},
},
}
payload, err := zm.helper.Marshal(metadata)
if err != nil {
return err
}
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
body, err := zm.post(resource, "")
if err != nil {
return nil, err
}
var metadata struct{ Result []zitadelMetadata }
err = zm.helper.Unmarshal(body, &metadata)
if err != nil {
return nil, err
}
return metadata.Result, nil
}
// post perform Post requests.
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 201 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s?%s", zm.managementEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// value returns string represented by the base64 string value.
func (zm zitadelMetadata) value() string {
value, err := base64.StdEncoding.DecodeString(zm.Value)
if err != nil {
return ""
}
return string(value)
}
// userData construct user data from zitadel profile.
func (zp zitadelProfile) userData() *UserData {
var (
email string
name string
wtAccountIDValue string
wtPendingInviteValue bool
)
for _, metadata := range zp.Metadata {
if metadata.Key == wtAccountID {
wtAccountIDValue = metadata.value()
}
if metadata.Key == wtPendingInvite {
value, err := strconv.ParseBool(metadata.value())
if err == nil {
wtPendingInviteValue = value
}
}
}
// Obtain the email for the human account and the login name,
// for the machine account.
if zp.Human != nil {
email = zp.Human.Email.Email
name = zp.Human.Profile.DisplayName
} else {
if len(zp.LoginNames) > 0 {
email = zp.LoginNames[0]
name = zp.LoginNames[0]
}
}
return &UserData{
Email: email,
Name: name,
ID: zp.ID,
AppMetadata: AppMetadata{
WTAccountID: wtAccountIDValue,
WTPendingInvite: &wtPendingInviteValue,
},
}
}
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
var firstName, lastName string
words := strings.Fields(name)
if n := len(words); n > 0 {
firstName = strings.Join(words[:n-1], " ")
lastName = words[n-1]
}
req := &zitadelUser{
UserName: name,
Profile: zitadelUserInfo{
FirstName: strings.TrimSpace(firstName),
LastName: strings.TrimSpace(lastName),
DisplayName: name,
},
Email: zitadelEmail{
Email: email,
IsEmailVerified: false,
},
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}

View File

@@ -0,0 +1,486 @@
package idp
import (
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewZitadelManager(t *testing.T) {
type test struct {
name string
inputConfig ZitadelClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := ZitadelClientConfig{
ClientID: "client_id",
ClientSecret: "client_secret",
GrantType: "client_credentials",
TokenEndpoint: "http://localhost/oauth/v2/token",
ManagementEndpoint: "http://localhost/management/v1",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.ClientID = ""
testCase2 := test{
name: "Missing ClientID Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase3Config := defaultTestConfig
testCase3Config.ClientSecret = ""
testCase3 := test{
name: "Missing ClientSecret Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewZitadelManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}
type mockZitadelCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
inputRespBody string
helper ManagerHelper
expectedFuncExitErrDiff error
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
requestJWTTokenTesttCase1 := requestJWTTokenTest{
name: "Good JWT Response",
inputCode: 200,
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
}
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputRespBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
}
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
resp, err := creds.requestJWTToken()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
} else {
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
err = testCase.helper.Unmarshal(body, &jwtToken)
assert.NoError(t, err, "unable to parse the json input")
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
}
})
}
}
func TestZitadelParseRequestJWTResponse(t *testing.T) {
type parseRequestJWTResponseTest struct {
name string
inputRespBody string
helper ManagerHelper
expectedToken string
expectedExpiresIn int
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
exp := 100
token := newTestJWT(t, exp)
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
name: "Parse Good JWT Body",
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
expectedExpiresIn: exp,
assertErrFunc: assert.NoError,
assertErrFuncMessage: "no error was expected",
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputRespBody: "",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
assertErrFunc: assert.Error,
assertErrFuncMessage: "json error was expected",
}
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
helper: testCase.helper,
}
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
})
}
}
func TestZitadelJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestZitadelAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := ZitadelClientConfig{}
creds := ZitadelCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &ZitadelManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct {
name string
invite bool
inputProfile zitadelProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "User Request",
invite: false,
inputProfile: zitadelProfile{
ID: "test1",
State: "USER_STATE_ACTIVE",
UserName: "test1@mail.com",
PreferredLoginName: "test1@mail.com",
LoginNames: []string{
"test1@mail.com",
},
Human: &zitadelUser{
Profile: zitadelUserInfo{
FirstName: "ZITADEL",
LastName: "Admin",
DisplayName: "ZITADEL Admin",
},
Email: zitadelEmail{
Email: "test1@mail.com",
IsEmailVerified: true,
},
},
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "ZmFsc2U=",
},
},
},
expectedUserData: UserData{
ID: "test1",
Name: "ZITADEL Admin",
Email: "test1@mail.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Service User Request",
invite: true,
inputProfile: zitadelProfile{
ID: "test2",
State: "USER_STATE_ACTIVE",
UserName: "machine",
PreferredLoginName: "machine",
LoginNames: []string{
"machine",
},
Human: nil,
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "dHJ1ZQ==",
},
},
},
expectedUserData: UserData{
ID: "test2",
Name: "machine",
Email: "machine",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}

View File

@@ -19,7 +19,7 @@ type MockAccountManager struct {
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
@@ -60,11 +60,11 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
@@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
} }
// CreatePAT mock implementation of GetPAT from server.AccountManager interface // CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil { if am.CreatePATFunc != nil {
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn) return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
} }
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface // DeletePAT mock implementation of DeletePAT from server.AccountManager interface
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if am.DeletePATFunc != nil { if am.DeletePATFunc != nil {
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID) return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
} }
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
} }
// GetPAT mock implementation of GetPAT from server.AccountManager interface // GetPAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if am.GetPATFunc != nil { if am.GetPATFunc != nil {
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID) return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
} }
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if am.GetAllPATsFunc != nil { if am.GetAllPATsFunc != nil {
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID) return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
} }
@@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented") return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
} }
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface // GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.IsUserAdminFunc != nil { if am.GetUserFunc != nil {
return am.IsUserAdminFunc(claims) return am.GetUserFunc(claims)
} }
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented") return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
} }
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager // UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
@@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
} }
// DeleteUser mocks DeleteUser of the AccountManager interface // DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error { func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil { if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, executingUserID, targetUserID) return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
} }
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
} }

View File

@@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
} }
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
if peerLoginExpired(peer, account) { if peerLoginExpired(peer, account) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
} }
@@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
} }
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
updateRemotePeers := false updateRemotePeers := false
if peerLoginExpired(peer, account) { if peerLoginExpired(peer, account) {
err = checkAuth(login.UserID, peer) err = checkAuth(login.UserID, peer)
@@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
} }
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
}
return nil
}
func checkAuth(loginUserID string, peer *Peer) error { func checkAuth(loginUserID string, peer *Peer) error {
if loginUserID == "" { if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided. // absence of a user ID indicates that JWT wasn't provided.

View File

@@ -51,15 +51,22 @@ type User struct {
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string AutoGroups []string
PATs map[string]*PersonalAccessToken PATs map[string]*PersonalAccessToken
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
} }
// IsAdmin returns true if user is an admin, false otherwise // IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
// IsAdmin returns true if the user is an admin, false otherwise
func (u *User) IsAdmin() bool { func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin return u.Role == UserRoleAdmin
} }
// toUserInfo converts a User object to a UserInfo object. // ToUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups autoGroups := u.AutoGroups
if autoGroups == nil { if autoGroups == nil {
autoGroups = []string{} autoGroups = []string{}
@@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: u.AutoGroups, AutoGroups: u.AutoGroups,
Status: string(UserStatusActive), Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser, IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil }, nil
} }
if userData.ID != u.Id { if userData.ID != u.Id {
@@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: string(userStatus), Status: string(userStatus),
IsServiceUser: u.IsServiceUser, IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil }, nil
} }
@@ -113,6 +122,7 @@ func (u *User) Copy() *User {
IsServiceUser: u.IsServiceUser, IsServiceUser: u.IsServiceUser,
ServiceUserName: u.ServiceUserName, ServiceUserName: u.ServiceUserName,
PATs: pats, PATs: pats,
Blocked: u.Blocked,
} }
} }
@@ -138,7 +148,7 @@ func NewAdminUser(id string) *User {
} }
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) { func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if executingUser == nil { if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
@@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
} }
meta := map[string]any{"name": newUser.ServiceUserName} meta := map[string]any{"name": newUser.ServiceUserName}
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{ return &UserInfo{
ID: newUser.Id, ID: newUser.Id,
@@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
} }
if user != nil { if user != nil {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
} }
users, err := am.idpManager.GetUserByEmail(invite.Email) users, err := am.idpManager.GetUserByEmail(invite.Email)
@@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
} }
if len(users) > 0 { if len(users) > 0 {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
} }
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
@@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
return newUser.toUserInfo(idpUser) return newUser.ToUserInfo(idpUser)
} }
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
return user, nil
}
// DeleteUser deletes a user from the given account. // DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error { func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if executingUser == nil { if executingUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
@@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
} }
meta := map[string]any{"name": targetUser.ServiceUserName} meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUserID) delete(account.Users, targetUserID)
@@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
} }
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.NotFound, "targetUser not found") return nil, status.Errorf(status.NotFound, "targetUser not found")
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if targetUser == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
} }
@@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
} }
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
return pat, nil return pat, nil
} }
// DeletePAT deletes a specific PAT from a user // DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if targetUser == nil { if targetUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
} }
@@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
} }
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(targetUser.PATs, tokenID) delete(targetUser.PATs, tokenID)
@@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
} }
// GetPAT returns a specific PAT from a user // GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if targetUser == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
} }
@@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID] executingUser := account.Users[initiatorUserID]
if targetUser == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
} }
@@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return pats, nil return pats, nil
} }
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now. // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
return nil, err return nil, err
} }
initiatorUser, err := account.FindUser(initiatorUserID)
if err != nil {
return nil, err
}
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
for _, newGroupID := range update.AutoGroups { for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok { if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id) newGroupID, update.Id)
} }
} }
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "update not found")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.AutoGroups = update.AutoGroups newUser.AutoGroups = update.AutoGroups
newUser.Role = update.Role
account.Users[newUser.Id] = newUser account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
var peerIDs []string
for _, peer := range blockedPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return nil, err
}
}
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
return nil, err
}
}
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(account); err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
// store activity logs
if oldUser.Role != newUser.Role { if oldUser.Role != newUser.Role {
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
} }
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) if update.AutoGroups != nil {
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
for _, g := range removedGroups { addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
group := account.GetGroup(g) for _, g := range removedGroups {
if group != nil { group := account.GetGroup(g)
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser, if group != nil {
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
} else { map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
} }
} for _, g := range addedGroups {
group := account.GetGroup(g)
for _, g := range addedGroups { if group != nil {
group := account.GetGroup(g) am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
if group != nil { map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser, }
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
}() }()
if !isNil(am.idpManager) && !newUser.IsServiceUser { if !isNil(am.idpManager) && !newUser.IsServiceUser {
@@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
if userData == nil { if userData == nil {
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id) return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
} }
return newUser.toUserInfo(userData) return newUser.ToUserInfo(userData)
} }
return newUser.toUserInfo(nil) return newUser.ToUserInfo(nil)
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
@@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
return account, nil return account, nil
} }
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return false, fmt.Errorf("get account: %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return false, status.Errorf(status.NotFound, "user not found")
}
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role. // based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
@@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
// if user is not an admin then show only current user and do not show other users // if user is not an admin then show only current user and do not show other users
continue continue
} }
info, err := accountUser.toUserInfo(nil) info, err := accountUser.ToUserInfo(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
var info *UserInfo var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.toUserInfo(queriedUser) info, err = localUser.ToUserInfo(queriedUser)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
LastUsed: time.Now(), LastUsed: time.Now(),
}, },
}, },
Blocked: false,
} }
err := validateStruct(user) err := validateStruct(user)
@@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
field := structVal.Field(i) field := structVal.Field(i)
fieldName := structType.Field(i).Name fieldName := structType.Field(i).Name
isSet := field.IsValid() && !field.IsZero() isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
if !isSet { if !isSet {
err = fmt.Errorf("%v%s in not set; ", err, fieldName) err = fmt.Errorf("%v%s in not set; ", err, fieldName)
@@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
assert.Errorf(t, err, "Regular users can not be deleted (yet)") assert.Errorf(t, err, "Regular users can not be deleted (yet)")
} }
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
@@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
UserId: mockUserID, UserId: mockUserID,
} }
ok, err := am.IsUserAdmin(claims) user, err := am.GetUser(claims)
if err != nil { if err != nil {
t.Fatalf("Error when checking user role: %s", err) t.Fatalf("Error when checking user role: %s", err)
} }
assert.True(t, ok) assert.Equal(t, mockUserID, user.Id)
assert.True(t, user.IsAdmin())
assert.False(t, user.IsBlocked())
} }
func TestUser_IsUserAdmin_ForUser(t *testing.T) { func TestUser_IsAdmin(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
Role: "user",
}
err := store.SaveAccount(account) user := NewAdminUser(mockUserID)
if err != nil { assert.True(t, user.IsAdmin())
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ user = NewRegularUser(mockUserID)
Store: store, assert.False(t, user.IsAdmin())
eventStore: &activity.InMemoryEventStore{},
}
claims := jwtclaims.AuthorizationClaims{
UserId: mockUserID,
}
ok, err := am.IsUserAdmin(claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.False(t, ok)
} }
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
@@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
assert.Equal(t, 1, len(users)) assert.Equal(t, 1, len(users))
assert.Equal(t, mockServiceUserID, users[0].ID) assert.Equal(t, mockServiceUserID, users[0].ID)
} }
func TestDefaultAccountManager_SaveUser(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
regularUserID := "regularUser"
tt := []struct {
name string
adminInitiator bool
update *User
expectedErr bool
}{
{
name: "Should_Fail_To_Update_Admin_Role",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleUser,
Blocked: false,
},
}, {
name: "Should_Fail_When_Admin_Blocks_Themselves",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_Non_Existing_User",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
expectedErr: true,
adminInitiator: false,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Update_User",
expectedErr: false,
adminInitiator: true,
update: &User{
Id: regularUserID,
Role: UserRoleAdmin,
Blocked: true,
},
},
}
for _, tc := range tt {
// create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
if err != nil {
t.Fatal(err)
}
// create a regular user
account.Users[regularUserID] = NewRegularUser(regularUserID)
err = manager.Store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
initiatorID := userID
if !tc.adminInitiator {
initiatorID = regularUserID
}
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error")
} else {
require.NoError(t, err, "expecting SaveUser not to throw an error")
assert.NotNil(t, updated)
assert.Equal(t, string(tc.update.Role), updated.Role)
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
}
}
}

View File

@@ -36,6 +36,7 @@ download_release_binary() {
echo "Installing $1 from $DOWNLOAD_URL" echo "Installing $1 from $DOWNLOAD_URL"
cd /tmp && curl -LO "$DOWNLOAD_URL" cd /tmp && curl -LO "$DOWNLOAD_URL"
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
INSTALL_DIR="/Applications/NetBird UI.app" INSTALL_DIR="/Applications/NetBird UI.app"
@@ -43,8 +44,9 @@ download_release_binary() {
unzip -q -o "$BINARY_NAME" unzip -q -o "$BINARY_NAME"
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
else else
sudo mkdir -p "$INSTALL_DIR"
tar -xzvf "$BINARY_NAME" tar -xzvf "$BINARY_NAME"
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR" sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
fi fi
} }

View File

@@ -2,44 +2,44 @@ package sharedsock
import "golang.org/x/net/bpf" import "golang.org/x/net/bpf"
// STUNFilter implements BPFFilter by filtering on STUN packets. // IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port.
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard). // Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
type STUNFilter struct { type IncomingSTUNFilter struct {
} }
// NewSTUNFilter creates an instance of a STUNFilter // NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter
func NewSTUNFilter() BPFFilter { func NewIncomingSTUNFilter() BPFFilter {
return &STUNFilter{} return &IncomingSTUNFilter{}
} }
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets // GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
func (sf STUNFilter) GetInstructions(port uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) { func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
raw4, err = rawInstructions(22, 32, port) raw4, err = rawInstructions(22, 32, dstPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
raw6, err = rawInstructions(2, 12, port) raw6, err = rawInstructions(2, 12, dstPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return raw4, raw6, nil return raw4, raw6, nil
} }
func rawInstructions(portOff, cookieOff, port uint32) ([]bpf.RawInstruction, error) { func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) {
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers // UDP raw socket for ipv4 receives the rcvdPacket with IP headers
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers // UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
instructions := []bpf.Instruction{ instructions := []bpf.Instruction{
// Load the source port from the UDP header (offset 22 for ipv4 and 2 for ipv6) // Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
bpf.LoadAbsolute{Off: portOff, Size: 2}, bpf.LoadAbsolute{Off: dstPortOff, Size: 2},
// Check if the source port is equal to the specified `port`. If not, skip the next 3 instructions. // Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions.
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: port, SkipTrue: 3}, bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3},
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6) // Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
bpf.LoadAbsolute{Off: cookieOff, Size: 4}, bpf.LoadAbsolute{Off: cookieOff, Size: 4},
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction. // Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1}, bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
// If both the port and the magic cookie match, return a positive value (0xffffffff) // If both the dstPort and the magic cookie match, return a positive value (0xffffffff)
bpf.RetConstant{Val: 0xffffffff}, bpf.RetConstant{Val: 0xffffffff},
// If either the port or the magic cookie doesn't match, return 0 // If either the dstPort or the magic cookie doesn't match, return 0
bpf.RetConstant{Val: 0}, bpf.RetConstant{Val: 0},
} }

View File

@@ -2,7 +2,7 @@
package sharedsock package sharedsock
// NewSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux // NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
func NewSTUNFilter() BPFFilter { func NewIncomingSTUNFilter() BPFFilter {
return nil return nil
} }

View File

@@ -27,7 +27,8 @@ import (
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped") var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered // SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
// by BPF instructions (e.g., STUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)). // by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
// It is meant to be used when sharing a port with some other process.
type SharedSocket struct { type SharedSocket struct {
ctx context.Context ctx context.Context
conn4 *socket.Conn conn4 *socket.Conn

View File

@@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
// create raw socket on a port // create raw socket on a port
testingPort := 51821 testingPort := 51821
rawSock, err := Listen(testingPort, NewSTUNFilter()) rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second)) err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
require.NoError(t, err, "unable to set deadline, error: %s", err) require.NoError(t, err, "unable to set deadline, error: %s", err)
@@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
func TestShouldNotReadNonSTUNPackets(t *testing.T) { func TestShouldNotReadNonSTUNPackets(t *testing.T) {
testingPort := 39439 testingPort := 39439
rawSock, err := Listen(testingPort, NewSTUNFilter()) rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
defer rawSock.Close() defer rawSock.Close()
@@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) {
defer udpListener.Close() defer udpListener.Close()
testingPort := 39440 testingPort := 39440
rawSock, err := Listen(testingPort, NewSTUNFilter()) rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
defer rawSock.Close() defer rawSock.Close()
@@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) {
} }
func TestSharedSocket_Close(t *testing.T) { func TestSharedSocket_Close(t *testing.T) {
rawSock, err := Listen(39440, NewSTUNFilter()) rawSock, err := Listen(39440, NewIncomingSTUNFilter())
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
errGrp := errgroup.Group{} errGrp := errgroup.Group{}

View File

@@ -1,4 +1,4 @@
//go:build !linux //go:build !linux || android
package sharedsock package sharedsock

View File

@@ -350,7 +350,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
} else if err != nil { } else if err != nil {
return err return err
} }
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key) log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg) decryptedMessage, err := c.decryptMessage(msg)
if err != nil { if err != nil {