mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
12 Commits
v0.19.0
...
separate_p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b43d7e8ef | ||
|
|
dcc83c8741 | ||
|
|
d56669ec2e | ||
|
|
e3d2b6a408 | ||
|
|
9f758b2015 | ||
|
|
2c50d7af1e | ||
|
|
e4c28f64fa | ||
|
|
6f2c4078ef | ||
|
|
f4ec1699ca | ||
|
|
fea53b2f0f | ||
|
|
60e6d0890a | ||
|
|
cb12e2da21 |
@@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
@@ -45,12 +46,16 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
PreSharedKey: &preSharedKey,
|
}
|
||||||
})
|
if preSharedKey != "" {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -106,7 +111,7 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,7 +190,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||||
@@ -199,11 +204,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||||
|
var codeMsg string
|
||||||
|
if !strings.Contains(verificationURIComplete, userCode) {
|
||||||
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
|
}
|
||||||
|
|
||||||
err := open.Run(verificationURIComplete)
|
err := open.Run(verificationURIComplete)
|
||||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
" " + verificationURIComplete + " \n\n")
|
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,14 +78,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
PreSharedKey: &preSharedKey,
|
|
||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
})
|
}
|
||||||
|
if preSharedKey != "" {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -172,7 +176,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -210,7 +209,7 @@ func (e *Engine) Start() error {
|
|||||||
e.udpMux = udpMux
|
e.udpMux = udpMux
|
||||||
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||||
} else {
|
} else {
|
||||||
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewSTUNFilter())
|
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -247,7 +246,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||||
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -758,9 +757,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
|
|
||||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
conf := conn.GetConf()
|
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||||
conf.StunTurn = append(e.STUNs, e.TURNs...)
|
|
||||||
conn.UpdateConf(conf)
|
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := conn.Open()
|
err := conn.Open()
|
||||||
@@ -789,9 +786,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
stunTurn = append(stunTurn, e.STUNs...)
|
stunTurn = append(stunTurn, e.STUNs...)
|
||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
|
|
||||||
proxyConfig := proxy.Config{
|
wgConfig := peer.WgConfig{
|
||||||
RemoteKey: pubKey,
|
RemoteKey: pubKey,
|
||||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
WgListenPort: e.config.WgPort,
|
||||||
WgInterface: e.wgInterface,
|
WgInterface: e.wgInterface,
|
||||||
AllowedIps: allowedIPs,
|
AllowedIps: allowedIPs,
|
||||||
PreSharedKey: e.config.PreSharedKey,
|
PreSharedKey: e.config.PreSharedKey,
|
||||||
@@ -808,7 +805,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
UDPMux: e.udpMux.UDPMuxDefault,
|
||||||
UDPMuxSrflx: e.udpMux,
|
UDPMuxSrflx: e.udpMux,
|
||||||
ProxyConfig: proxyConfig,
|
WgConfig: wgConfig,
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||||
|
|||||||
@@ -148,6 +148,11 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|||||||
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||||
|
if deviceCode.VerificationURIComplete == "" {
|
||||||
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
return deviceCode, err
|
return deviceCode, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
@@ -28,8 +28,18 @@ const (
|
|||||||
|
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
|
||||||
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WgConfig struct {
|
||||||
|
WgListenPort int
|
||||||
|
RemoteKey string
|
||||||
|
WgInterface *iface.WGIface
|
||||||
|
AllowedIps string
|
||||||
|
PreSharedKey *wgtypes.Key
|
||||||
|
}
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
|
|
||||||
@@ -48,7 +58,7 @@ type ConnConfig struct {
|
|||||||
|
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
ProxyConfig proxy.Config
|
WgConfig WgConfig
|
||||||
|
|
||||||
UDPMux ice.UDPMux
|
UDPMux ice.UDPMux
|
||||||
UDPMuxSrflx ice.UniversalUDPMux
|
UDPMuxSrflx ice.UniversalUDPMux
|
||||||
@@ -103,7 +113,7 @@ type Conn struct {
|
|||||||
|
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
|
|
||||||
proxy proxy.Proxy
|
proxy *WireGuardProxy
|
||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
@@ -127,9 +137,14 @@ func (conn *Conn) GetConf() ConnConfig {
|
|||||||
return conn.config
|
return conn.config
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConf updates the connection config
|
// WgConfig returns the WireGuard config
|
||||||
func (conn *Conn) UpdateConf(conf ConnConfig) {
|
func (conn *Conn) WgConfig() WgConfig {
|
||||||
conn.config = conf
|
return conn.config.WgConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStunTurn update the turn and stun addresses
|
||||||
|
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
||||||
|
conn.config.StunTurn = turnStun
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -240,12 +255,12 @@ func readICEAgentConfigProperties() (time.Duration, time.Duration) {
|
|||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open() error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -300,10 +315,11 @@ func (conn *Conn) Open() error {
|
|||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
peerState = State{PubKey: conn.config.Key}
|
peerState = State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -334,19 +350,12 @@ func (conn *Conn) Open() error {
|
|||||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||||
}
|
}
|
||||||
// the ice connection has been established successfully so we are ready to start the proxy
|
// the ice connection has been established successfully so we are ready to start the proxy
|
||||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
|
||||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
|
||||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
|
||||||
// direct Wireguard connection
|
|
||||||
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||||
select {
|
select {
|
||||||
@@ -363,54 +372,58 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
|||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
var pair *ice.CandidatePair
|
|
||||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
var endpoint net.Addr
|
||||||
p := conn.getProxy(pair, remoteWgPort)
|
if isRelayCandidate(pair.Local) {
|
||||||
conn.proxy = p
|
conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
|
||||||
err = p.Start(remoteConn)
|
endpoint, err = conn.proxy.Start()
|
||||||
|
if err != nil {
|
||||||
|
conn.proxy = nil
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||||
|
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||||
|
endpoint = remoteConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if conn.proxy != nil {
|
||||||
|
_ = conn.proxy.Close()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status
|
peerState := State{
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
PubKey: conn.config.Key,
|
||||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
ConnStatus: conn.status,
|
||||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
|
Direct: conn.proxy == nil,
|
||||||
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
}
|
}
|
||||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
|
||||||
|
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return endpoint, nil
|
||||||
}
|
|
||||||
|
|
||||||
// todo rename this method and the proxy package to something more appropriate
|
|
||||||
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
|
||||||
if isRelayCandidate(pair.Local) {
|
|
||||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
|
||||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
|
||||||
|
|
||||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
@@ -439,20 +452,22 @@ func (conn *Conn) cleanup() error {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2, err3 error
|
||||||
if conn.agent != nil {
|
if conn.agent != nil {
|
||||||
err := conn.agent.Close()
|
err1 = conn.agent.Close()
|
||||||
if err != nil {
|
if err1 == nil {
|
||||||
return err
|
conn.agent = nil
|
||||||
}
|
}
|
||||||
conn.agent = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
|
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
if conn.proxy != nil {
|
if conn.proxy != nil {
|
||||||
err := conn.proxy.Close()
|
err3 = conn.proxy.Close()
|
||||||
if err != nil {
|
if err3 != nil {
|
||||||
return err
|
conn.proxy = nil
|
||||||
}
|
}
|
||||||
conn.proxy = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
@@ -462,10 +477,11 @@ func (conn *Conn) cleanup() error {
|
|||||||
|
|
||||||
conn.status = StatusDisconnected
|
conn.status = StatusDisconnected
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
peerState.ConnStatus = conn.status
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatus: conn.status,
|
||||||
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||||
@@ -474,8 +490,13 @@ func (conn *Conn) cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||||
|
if err1 != nil {
|
||||||
return nil
|
return err1
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
return err3
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package proxy
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
@@ -11,67 +12,45 @@ type WireGuardProxy struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
config Config
|
wgListenPort int
|
||||||
|
remoteKey string
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
|
||||||
p := &WireGuardProxy{config: config}
|
p := &WireGuardProxy{
|
||||||
|
wgListenPort: wgListenPort,
|
||||||
|
remoteKey: remoteKey,
|
||||||
|
remoteConn: remoteConn,
|
||||||
|
}
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) updateEndpoint() error {
|
func (p *WireGuardProxy) Start() (net.Addr, error) {
|
||||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// add local proxy connection as a Wireguard peer
|
|
||||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
udpAddr, p.config.PreSharedKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.remoteConn = remoteConn
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
err = p.updateEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
p.localConn = lConn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
go p.proxyToLocal()
|
go p.proxyToLocal()
|
||||||
|
|
||||||
return nil
|
return lConn.LocalAddr(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) Close() error {
|
func (p *WireGuardProxy) Close() error {
|
||||||
p.cancel()
|
p.cancel()
|
||||||
if c := p.localConn; c != nil {
|
if p.localConn != nil {
|
||||||
err := p.localConn.Close()
|
err := p.localConn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +62,7 @@ func (p *WireGuardProxy) proxyToRemote() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
|
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
@@ -107,7 +86,7 @@ func (p *WireGuardProxy) proxyToLocal() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
|
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
@@ -122,7 +101,3 @@ func (p *WireGuardProxy) proxyToLocal() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireGuardProxy) Type() Type {
|
|
||||||
return TypeWireGuard
|
|
||||||
}
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DummyProxy just sends pings to the RemoteKey peer and reads responses
|
|
||||||
type DummyProxy struct {
|
|
||||||
conn net.Conn
|
|
||||||
remote string
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDummyProxy(remote string) *DummyProxy {
|
|
||||||
p := &DummyProxy{remote: remote}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Close() error {
|
|
||||||
p.cancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.conn = remoteConn
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Write([]byte("hello"))
|
|
||||||
//log.Debugf("sent ping to %s", p.remote)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Type() Type {
|
|
||||||
return TypeDummy
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
|
||||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
|
||||||
type NoProxy struct {
|
|
||||||
config Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNoProxy creates a new NoProxy with a provided config
|
|
||||||
func NewNoProxy(config Config) *NoProxy {
|
|
||||||
return &NoProxy{config: config}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close removes peer from the WireGuard interface
|
|
||||||
func (p *NoProxy) Close() error {
|
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start just updates WireGuard peer with the remote address
|
|
||||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
|
||||||
|
|
||||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
addr, p.config.PreSharedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *NoProxy) Type() Type {
|
|
||||||
return TypeNoProxy
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultWgKeepAlive = 25 * time.Second
|
|
||||||
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
|
||||||
TypeWireGuard Type = "WireGuard"
|
|
||||||
TypeDummy Type = "Dummy"
|
|
||||||
TypeNoProxy Type = "NoProxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
WgListenAddr string
|
|
||||||
RemoteKey string
|
|
||||||
WgInterface *iface.WGIface
|
|
||||||
AllowedIps string
|
|
||||||
PreSharedKey *wgtypes.Key
|
|
||||||
}
|
|
||||||
|
|
||||||
type Proxy interface {
|
|
||||||
io.Closer
|
|
||||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
|
||||||
Start(remoteConn net.Conn) error
|
|
||||||
Type() Type
|
|
||||||
}
|
|
||||||
@@ -77,12 +77,17 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
// Endpoint is optional
|
// Endpoint is optional
|
||||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint net.Addr, preSharedKey *wgtypes.Key) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
rAddr, err := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||||
|
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, rAddr, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
|
|||||||
@@ -396,6 +396,12 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
|||||||
}
|
}
|
||||||
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||||
|
|
||||||
|
log.Infof("configuring IdpManagerConfig.OIDCConfig.Issuer with a new value %s,", oidcConfig.Issuer)
|
||||||
|
config.IdpManagerConfig.OIDCConfig.Issuer = strings.TrimRight(oidcConfig.Issuer, "/")
|
||||||
|
|
||||||
|
log.Infof("configuring IdpManagerConfig.OIDCConfig.TokenEndpoint with a new value %s,", oidcConfig.TokenEndpoint)
|
||||||
|
config.IdpManagerConfig.OIDCConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
|
||||||
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||||
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
|
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
|
||||||
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||||
@@ -441,7 +447,7 @@ type OIDCConfigResponse struct {
|
|||||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||||
res, err := http.Get(oidcEndpoint)
|
res, err := http.Get(oidcEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
|
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -49,16 +49,16 @@ type AccountManager interface {
|
|||||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
MarkPATUsed(tokenID string) error
|
MarkPATUsed(tokenID string) error
|
||||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
AccountExists(accountId string) (*bool, error)
|
AccountExists(accountId string) (*bool, error)
|
||||||
GetPeerByKey(peerKey string) (*Peer, error)
|
GetPeerByKey(peerKey string) (*Peer, error)
|
||||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||||
@@ -69,10 +69,10 @@ type AccountManager interface {
|
|||||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||||
GetPeerNetwork(peerID string) (*Network, error)
|
GetPeerNetwork(peerID string) (*Network, error)
|
||||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||||
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||||
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
@@ -179,6 +179,7 @@ type UserInfo struct {
|
|||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
Status string `json:"-"`
|
Status string `json:"-"`
|
||||||
IsServiceUser bool `json:"is_service_user"`
|
IsServiceUser bool `json:"is_service_user"`
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||||
@@ -902,7 +903,9 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
|||||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||||
users := make(map[string]struct{}, len(account.Users))
|
users := make(map[string]struct{}, len(account.Users))
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
users[user.Id] = struct{}{}
|
if !user.IsServiceUser {
|
||||||
|
users[user.Id] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||||
userData, err := am.lookupCache(users, account.Id)
|
userData, err := am.lookupCache(users, account.Id)
|
||||||
|
|||||||
@@ -91,6 +91,10 @@ const (
|
|||||||
ServiceUserCreated
|
ServiceUserCreated
|
||||||
// ServiceUserDeleted indicates that a user deleted a service user
|
// ServiceUserDeleted indicates that a user deleted a service user
|
||||||
ServiceUserDeleted
|
ServiceUserDeleted
|
||||||
|
// UserBlocked indicates that a user blocked another user
|
||||||
|
UserBlocked
|
||||||
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
|
UserUnblocked
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -184,6 +188,10 @@ const (
|
|||||||
ServiceUserCreatedMessage string = "Service user created"
|
ServiceUserCreatedMessage string = "Service user created"
|
||||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||||
ServiceUserDeletedMessage string = "Service user deleted"
|
ServiceUserDeletedMessage string = "Service user deleted"
|
||||||
|
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
|
||||||
|
UserBlockedMessage string = "User blocked"
|
||||||
|
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||||
|
UserUnblockedMessage string = "User unblocked"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Activity that triggered an Event
|
// Activity that triggered an Event
|
||||||
@@ -282,6 +290,10 @@ func (a Activity) Message() string {
|
|||||||
return ServiceUserCreatedMessage
|
return ServiceUserCreatedMessage
|
||||||
case ServiceUserDeleted:
|
case ServiceUserDeleted:
|
||||||
return ServiceUserDeletedMessage
|
return ServiceUserDeletedMessage
|
||||||
|
case UserBlocked:
|
||||||
|
return UserBlockedMessage
|
||||||
|
case UserUnblocked:
|
||||||
|
return UserUnblockedMessage
|
||||||
default:
|
default:
|
||||||
return "UNKNOWN_ACTIVITY"
|
return "UNKNOWN_ACTIVITY"
|
||||||
}
|
}
|
||||||
@@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
|
|||||||
return "user.join"
|
return "user.join"
|
||||||
case UserInvited:
|
case UserInvited:
|
||||||
return "user.invite"
|
return "user.invite"
|
||||||
|
case UserBlocked:
|
||||||
|
return "user.block"
|
||||||
|
case UserUnblocked:
|
||||||
|
return "user.unblock"
|
||||||
case AccountCreated:
|
case AccountCreated:
|
||||||
return "account.create"
|
return "account.create"
|
||||||
case RuleAdded:
|
case RuleAdded:
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ components:
|
|||||||
status:
|
status:
|
||||||
description: User's status
|
description: User's status
|
||||||
type: string
|
type: string
|
||||||
enum: [ "active","invited","disabled" ]
|
enum: [ "active","invited","blocked" ]
|
||||||
auto_groups:
|
auto_groups:
|
||||||
description: Groups to auto-assign to peers registered by this user
|
description: Groups to auto-assign to peers registered by this user
|
||||||
type: array
|
type: array
|
||||||
@@ -79,6 +79,9 @@ components:
|
|||||||
description: Is true if this user is a service user
|
description: Is true if this user is a service user
|
||||||
type: boolean
|
type: boolean
|
||||||
readOnly: true
|
readOnly: true
|
||||||
|
is_blocked:
|
||||||
|
description: Is true if this user is blocked. Blocked users can't use the system
|
||||||
|
type: boolean
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- email
|
- email
|
||||||
@@ -86,6 +89,7 @@ components:
|
|||||||
- role
|
- role
|
||||||
- auto_groups
|
- auto_groups
|
||||||
- status
|
- status
|
||||||
|
- is_blocked
|
||||||
UserRequest:
|
UserRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -97,9 +101,13 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
is_blocked:
|
||||||
|
description: If set to true then user is blocked and can't use the system
|
||||||
|
type: boolean
|
||||||
required:
|
required:
|
||||||
- role
|
- role
|
||||||
- auto_groups
|
- auto_groups
|
||||||
|
- is_blocked
|
||||||
UserCreateRequest:
|
UserCreateRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -645,7 +653,7 @@ components:
|
|||||||
description: The string code of the activity that occurred during the event
|
description: The string code of the activity that occurred during the event
|
||||||
type: string
|
type: string
|
||||||
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
||||||
"user.role.update",
|
"user.role.update", "user.block", "user.unblock",
|
||||||
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
||||||
"setupkey.group.delete", "setupkey.group.add",
|
"setupkey.group.delete", "setupkey.group.add",
|
||||||
"rule.add", "rule.delete", "rule.update",
|
"rule.add", "rule.delete", "rule.update",
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ const (
|
|||||||
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
||||||
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
||||||
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
||||||
|
EventActivityCodeUserBlock EventActivityCode = "user.block"
|
||||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||||
@@ -53,6 +54,7 @@ const (
|
|||||||
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
||||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||||
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
||||||
|
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for NameserverNsType.
|
// Defines values for NameserverNsType.
|
||||||
@@ -68,9 +70,9 @@ const (
|
|||||||
|
|
||||||
// Defines values for UserStatus.
|
// Defines values for UserStatus.
|
||||||
const (
|
const (
|
||||||
UserStatusActive UserStatus = "active"
|
UserStatusActive UserStatus = "active"
|
||||||
UserStatusDisabled UserStatus = "disabled"
|
UserStatusBlocked UserStatus = "blocked"
|
||||||
UserStatusInvited UserStatus = "invited"
|
UserStatusInvited UserStatus = "invited"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account defines model for Account.
|
// Account defines model for Account.
|
||||||
@@ -552,6 +554,9 @@ type User struct {
|
|||||||
// Id User ID
|
// Id User ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
|
|
||||||
// IsCurrent Is true if authenticated user is the same as this user
|
// IsCurrent Is true if authenticated user is the same as this user
|
||||||
IsCurrent *bool `json:"is_current,omitempty"`
|
IsCurrent *bool `json:"is_current,omitempty"`
|
||||||
|
|
||||||
@@ -594,6 +599,9 @@ type UserRequest struct {
|
|||||||
// AutoGroups Groups to auto-assign to peers registered by this user
|
// AutoGroups Groups to auto-assign to peers registered by this user
|
||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
|
// IsBlocked If set to true then user is blocked and can't use the system
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
|
|
||||||
// Role User's NetBird account role
|
// Role User's NetBird account role
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
|||||||
acMiddleware := middleware.NewAccessControl(
|
acMiddleware := middleware.NewAccessControl(
|
||||||
authCfg.Audience,
|
authCfg.Audience,
|
||||||
authCfg.UserIDClaim,
|
authCfg.UserIDClaim,
|
||||||
accountManager.IsUserAdmin)
|
accountManager.GetUser)
|
||||||
|
|
||||||
rootRouter := mux.NewRouter()
|
rootRouter := mux.NewRouter()
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
|
|||||||
@@ -6,28 +6,30 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||||
|
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
|
|
||||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||||
type AccessControl struct {
|
type AccessControl struct {
|
||||||
isUserAdmin IsUserAdminFunc
|
|
||||||
claimsExtract jwtclaims.ClaimsExtractor
|
claimsExtract jwtclaims.ClaimsExtractor
|
||||||
|
getUser GetUser
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccessControl instance constructor
|
// NewAccessControl instance constructor
|
||||||
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||||
return &AccessControl{
|
return &AccessControl{
|
||||||
isUserAdmin: isUserAdmin,
|
|
||||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(audience),
|
jwtclaims.WithAudience(audience),
|
||||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||||
),
|
),
|
||||||
|
getUser: getUser,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := a.claimsExtract.FromRequestContext(r)
|
claims := a.claimsExtract.FromRequestContext(r)
|
||||||
|
|
||||||
ok, err := a.isUserAdmin(claims)
|
user, err := a.getUser(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !ok {
|
|
||||||
|
if user.IsBlocked() {
|
||||||
|
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.IsAdmin() {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||||
|
|
||||||
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Regex failed")
|
log.Debugf("regex failed")
|
||||||
util.WriteError(status.Errorf(status.Internal, ""), w)
|
util.WriteError(status.Errorf(status.Internal, ""), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
log.Debugf("Valid Path")
|
log.Debugf("valid Path")
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ var testAccount = &server.Account{
|
|||||||
func initPATTestData() *PATHandler {
|
func initPATTestData() *PATHandler {
|
||||||
return &PATHandler{
|
return &PATHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
|
|||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testAccount, testAccount.Users[existingUserID], nil
|
return testAccount, testAccount.Users[existingUserID], nil
|
||||||
},
|
},
|
||||||
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
|
|||||||
}
|
}
|
||||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||||
},
|
},
|
||||||
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.AutoGroups == nil {
|
||||||
|
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
userRole := server.StrRoleToUserRole(req.Role)
|
userRole := server.StrRoleToUserRole(req.Role)
|
||||||
if userRole == server.UserRoleUnknown {
|
if userRole == server.UserRoleUnknown {
|
||||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||||
@@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
Id: userID,
|
Id: userID,
|
||||||
Role: userRole,
|
Role: userRole,
|
||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
|
Blocked: req.IsBlocked,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
|||||||
case "invited":
|
case "invited":
|
||||||
userStatus = api.UserStatusInvited
|
userStatus = api.UserStatusInvited
|
||||||
default:
|
default:
|
||||||
userStatus = api.UserStatusDisabled
|
userStatus = api.UserStatusBlocked
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsBlocked {
|
||||||
|
userStatus = api.UserStatusBlocked
|
||||||
}
|
}
|
||||||
|
|
||||||
isCurrent := user.ID == currenUserID
|
isCurrent := user.ID == currenUserID
|
||||||
@@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
|||||||
Status: userStatus,
|
Status: userStatus,
|
||||||
IsCurrent: &isCurrent,
|
IsCurrent: &isCurrent,
|
||||||
IsServiceUser: &user.IsServiceUser,
|
IsServiceUser: &user.IsServiceUser,
|
||||||
|
IsBlocked: user.IsBlocked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
|
|||||||
Id: existingUserID,
|
Id: existingUserID,
|
||||||
Role: "admin",
|
Role: "admin",
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
regularUserID: {
|
regularUserID: {
|
||||||
Id: regularUserID,
|
Id: regularUserID,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
serviceUserID: {
|
serviceUserID: {
|
||||||
Id: serviceUserID,
|
Id: serviceUserID,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
IsServiceUser: true,
|
IsServiceUser: true,
|
||||||
|
AutoGroups: []string{"group_1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
},
|
},
|
||||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
if targetUserID == notFoundUserID {
|
if targetUserID == notFoundUserID {
|
||||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||||
}
|
}
|
||||||
@@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||||
|
if update.Id == notFoundUserID {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID != existingUserID {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := update.Copy().ToUserInfo(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateUser(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatusCode int
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestBody io.Reader
|
||||||
|
expectedUserID string
|
||||||
|
expectedRole string
|
||||||
|
expectedStatus string
|
||||||
|
expectedBlocked bool
|
||||||
|
expectedIsServiceUser bool
|
||||||
|
expectedGroups []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Update_Block_User",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: true,
|
||||||
|
expectedRole: "user",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_1"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update_Change_Role_To_Admin",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_1"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update_Groups",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_2", "group_3"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_Because_AutoGroups_Is_Absent",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/users/" + regularUserID,
|
||||||
|
expectedStatusCode: http.StatusBadRequest,
|
||||||
|
expectedUserID: regularUserID,
|
||||||
|
expectedBlocked: false,
|
||||||
|
expectedRole: "admin",
|
||||||
|
expectedStatus: "blocked",
|
||||||
|
expectedGroups: []string{"group_2", "group_3"},
|
||||||
|
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
userHandler := initUsersTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatusCode {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedStatusCode == 200 {
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := &api.User{}
|
||||||
|
err = json.Unmarshal(content, &respBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("response content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedUserID, respBody.Id)
|
||||||
|
assert.Equal(t, tc.expectedRole, respBody.Role)
|
||||||
|
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
|
||||||
|
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
|
||||||
|
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
|
||||||
|
|
||||||
|
for _, expectedGroup := range tc.expectedGroups {
|
||||||
|
exists := false
|
||||||
|
for _, actualGroup := range respBody.AutoGroups {
|
||||||
|
if expectedGroup == actualGroup {
|
||||||
|
exists = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateUser(t *testing.T) {
|
func TestCreateUser(t *testing.T) {
|
||||||
name := "name"
|
name := "name"
|
||||||
email := "email"
|
email := "email"
|
||||||
|
|||||||
@@ -32,10 +32,10 @@ type Auth0Manager struct {
|
|||||||
// Auth0ClientConfig auth0 manager client configurations
|
// Auth0ClientConfig auth0 manager client configurations
|
||||||
type Auth0ClientConfig struct {
|
type Auth0ClientConfig struct {
|
||||||
Audience string
|
Audience string
|
||||||
AuthIssuer string
|
AuthIssuer string `json:"-"`
|
||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
GrantType string
|
GrantType string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// auth0JWTRequest payload struct to request a JWT Token
|
// auth0JWTRequest payload struct to request a JWT Token
|
||||||
@@ -110,7 +110,8 @@ type auth0Profile struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
func NewAuth0Manager(oidcConfig OIDCConfig, config Auth0ClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||||
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
@@ -121,17 +122,19 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
config.AuthIssuer = oidcConfig.TokenEndpoint
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.Audience == "" || config.AuthIssuer == "" {
|
if config.ClientID == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration is not complete")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, clientID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GrantType != "client_credentials" {
|
if config.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration failed. Grant Type should be client_credentials")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(strings.ToLower(config.AuthIssuer), "https://") {
|
if config.Audience == "" {
|
||||||
return nil, fmt.Errorf("auth0 idp configuration failed. AuthIssuer should contain https://")
|
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials := &Auth0Credentials{
|
credentials := &Auth0Credentials{
|
||||||
|
|||||||
@@ -459,26 +459,9 @@ func TestNewAuth0Manager(t *testing.T) {
|
|||||||
testCase3Config := defaultTestConfig
|
testCase3Config := defaultTestConfig
|
||||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||||
|
|
||||||
testCase3 := test{
|
for _, testCase := range []test{testCase1, testCase2} {
|
||||||
name: "Wrong Auth Issuer Format",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when wrong auth issuer format",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase4Config := defaultTestConfig
|
|
||||||
testCase4Config.GrantType = "spa"
|
|
||||||
|
|
||||||
testCase4 := test{
|
|
||||||
name: "Wrong Grant Type",
|
|
||||||
inputConfig: testCase4Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when wrong grant type",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
_, err := NewAuth0Manager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,12 +37,13 @@ type AzureManager struct {
|
|||||||
|
|
||||||
// AzureClientConfig azure manager client configurations.
|
// AzureClientConfig azure manager client configurations.
|
||||||
type AzureClientConfig struct {
|
type AzureClientConfig struct {
|
||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
GraphAPIEndpoint string
|
ObjectID string
|
||||||
ObjectID string
|
|
||||||
TokenEndpoint string
|
GraphAPIEndpoint string `json:"-"`
|
||||||
GrantType string
|
TokenEndpoint string `json:"-"`
|
||||||
|
GrantType string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AzureCredentials azure authentication information.
|
// AzureCredentials azure authentication information.
|
||||||
@@ -74,7 +75,8 @@ type azureExtension struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAzureManager creates a new instance of the AzureManager.
|
// NewAzureManager creates a new instance of the AzureManager.
|
||||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
func NewAzureManager(oidcConfig OIDCConfig, config AzureClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -84,13 +86,20 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.GraphAPIEndpoint = "https://graph.microsoft.com"
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.GraphAPIEndpoint == "" || config.TokenEndpoint == "" {
|
if config.ClientID == "" {
|
||||||
return nil, fmt.Errorf("azure idp configuration is not complete")
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GrantType != "client_credentials" {
|
if config.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("azure idp configuration failed. Grant Type should be client_credentials")
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ObjectID == "" {
|
||||||
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials := &AzureCredentials{
|
credentials := &AzureCredentials{
|
||||||
|
|||||||
@@ -19,12 +19,21 @@ type Manager interface {
|
|||||||
GetUserByEmail(email string) ([]*UserData, error)
|
GetUserByEmail(email string) ([]*UserData, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OIDCConfig specifies configuration for OpenID Connect provider
|
||||||
|
// These configurations are automatically loaded from the OIDC endpoint
|
||||||
|
type OIDCConfig struct {
|
||||||
|
Issuer string
|
||||||
|
TokenEndpoint string
|
||||||
|
}
|
||||||
|
|
||||||
// Config an idp configuration struct to be loaded from management server's config file
|
// Config an idp configuration struct to be loaded from management server's config file
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ManagerType string
|
ManagerType string
|
||||||
|
OIDCConfig OIDCConfig `json:"-"`
|
||||||
Auth0ClientCredentials Auth0ClientConfig
|
Auth0ClientCredentials Auth0ClientConfig
|
||||||
KeycloakClientCredentials KeycloakClientConfig
|
|
||||||
AzureClientCredentials AzureClientConfig
|
AzureClientCredentials AzureClientConfig
|
||||||
|
KeycloakClientCredentials KeycloakClientConfig
|
||||||
|
ZitadelClientCredentials ZitadelClientConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||||
@@ -73,11 +82,13 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
case "none", "":
|
case "none", "":
|
||||||
return nil, nil
|
return nil, nil
|
||||||
case "auth0":
|
case "auth0":
|
||||||
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
return NewAuth0Manager(config.OIDCConfig, config.Auth0ClientCredentials, appMetrics)
|
||||||
case "azure":
|
case "azure":
|
||||||
return NewAzureManager(config.AzureClientCredentials, appMetrics)
|
return NewAzureManager(config.OIDCConfig, config.AzureClientCredentials, appMetrics)
|
||||||
case "keycloak":
|
case "keycloak":
|
||||||
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
|
return NewKeycloakManager(config.OIDCConfig, config.KeycloakClientCredentials, appMetrics)
|
||||||
|
case "zitadel":
|
||||||
|
return NewZitadelManager(config.OIDCConfig, config.ZitadelClientCredentials, appMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ type KeycloakClientConfig struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
AdminEndpoint string
|
AdminEndpoint string
|
||||||
TokenEndpoint string
|
TokenEndpoint string `json:"-"`
|
||||||
GrantType string
|
GrantType string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeycloakCredentials keycloak authentication information.
|
// KeycloakCredentials keycloak authentication information.
|
||||||
@@ -82,7 +82,8 @@ type keycloakProfile struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
func NewKeycloakManager(oidcConfig OIDCConfig, config KeycloakClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -92,13 +93,19 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
|||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
|
if config.ClientID == "" {
|
||||||
return nil, fmt.Errorf("keycloak idp configuration is not complete")
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GrantType != "client_credentials" {
|
if config.ClientSecret == "" {
|
||||||
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AdminEndpoint == "" {
|
||||||
|
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials := &KeycloakCredentials{
|
credentials := &KeycloakCredentials{
|
||||||
|
|||||||
@@ -46,19 +46,19 @@ func TestNewKeycloakManager(t *testing.T) {
|
|||||||
assertErrFuncMessage: "should return error when field empty",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
}
|
}
|
||||||
|
|
||||||
testCase5Config := defaultTestConfig
|
testCase3Config := defaultTestConfig
|
||||||
testCase5Config.GrantType = "authorization_code"
|
testCase3Config.ClientSecret = ""
|
||||||
|
|
||||||
testCase5 := test{
|
testCase3 := test{
|
||||||
name: "Wrong GrantType",
|
name: "Missing ClientSecret Configuration",
|
||||||
inputConfig: testCase5Config,
|
inputConfig: testCase3Config,
|
||||||
assertErrFunc: require.Error,
|
assertErrFunc: require.Error,
|
||||||
assertErrFuncMessage: "should return error when wrong grant type",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase5} {
|
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
_, err := NewKeycloakManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
609
management/server/idp/zitadel.go
Normal file
609
management/server/idp/zitadel.go
Normal file
@@ -0,0 +1,609 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ZitadelManager zitadel manager client instance.
|
||||||
|
type ZitadelManager struct {
|
||||||
|
managementEndpoint string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
credentials ManagerCredentials
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZitadelClientConfig zitadel manager client configurations.
|
||||||
|
type ZitadelClientConfig struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
GrantType string `json:"-"`
|
||||||
|
TokenEndpoint string `json:"-"`
|
||||||
|
ManagementEndpoint string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZitadelCredentials zitadel authentication information.
|
||||||
|
type ZitadelCredentials struct {
|
||||||
|
clientConfig ZitadelClientConfig
|
||||||
|
helper ManagerHelper
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
jwtToken JWTToken
|
||||||
|
mux sync.Mutex
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelEmail specifies details of a user email.
|
||||||
|
type zitadelEmail struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
IsEmailVerified bool `json:"isEmailVerified"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelUserInfo specifies user information.
|
||||||
|
type zitadelUserInfo struct {
|
||||||
|
FirstName string `json:"firstName"`
|
||||||
|
LastName string `json:"lastName"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelUser specifies profile details for user account.
|
||||||
|
type zitadelUser struct {
|
||||||
|
UserName string `json:"userName,omitempty"`
|
||||||
|
Profile zitadelUserInfo `json:"profile"`
|
||||||
|
Email zitadelEmail `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type zitadelAttributes map[string][]map[string]any
|
||||||
|
|
||||||
|
// zitadelMetadata holds additional user data.
|
||||||
|
type zitadelMetadata struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelProfile represents an zitadel user profile response.
|
||||||
|
type zitadelProfile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
State string `json:"state"`
|
||||||
|
UserName string `json:"userName"`
|
||||||
|
PreferredLoginName string `json:"preferredLoginName"`
|
||||||
|
LoginNames []string `json:"loginNames"`
|
||||||
|
Human *zitadelUser `json:"human"`
|
||||||
|
Metadata []zitadelMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewZitadelManager creates a new instance of the ZitadelManager.
|
||||||
|
func NewZitadelManager(oidcConfig OIDCConfig, config ZitadelClientConfig,
|
||||||
|
appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
helper := JsonParser{}
|
||||||
|
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||||
|
config.ManagementEndpoint = fmt.Sprintf("%s/management/v1", oidcConfig.Issuer)
|
||||||
|
config.GrantType = "client_credentials"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, clientID is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ClientSecret == "" {
|
||||||
|
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, ClientSecret is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := &ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ZitadelManager{
|
||||||
|
managementEndpoint: config.ManagementEndpoint,
|
||||||
|
httpClient: httpClient,
|
||||||
|
credentials: credentials,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from zitadel.
|
||||||
|
func (zc *ZitadelCredentials) jwtStillValid() bool {
|
||||||
|
return !zc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(zc.jwtToken.expiresInTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestJWTToken performs request to get jwt token.
|
||||||
|
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", zc.clientConfig.ClientID)
|
||||||
|
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||||
|
data.Set("grant_type", zc.clientConfig.GrantType)
|
||||||
|
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
|
||||||
|
|
||||||
|
payload := strings.NewReader(data.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
log.Debug("requesting new jwt token for zitadel idp manager")
|
||||||
|
|
||||||
|
resp, err := zc.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zc.appMetrics != nil {
|
||||||
|
zc.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds.
|
||||||
|
func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
body, err := io.ReadAll(rawBody)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = zc.helper.Unmarshal(body, &jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||||
|
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exp maps into exp from jwt token
|
||||||
|
var IssuedAt struct{ Exp int64 }
|
||||||
|
err = zc.helper.Unmarshal(data, &IssuedAt)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||||
|
|
||||||
|
return jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate retrieves access token to use the Zitadel Management API.
|
||||||
|
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
zc.mux.Lock()
|
||||||
|
defer zc.mux.Unlock()
|
||||||
|
|
||||||
|
if zc.appMetrics != nil {
|
||||||
|
zc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reuse the token without requesting a new one if it is not expired,
|
||||||
|
// and if expiry time is sufficient time available to make a request.
|
||||||
|
if zc.jwtStillValid() {
|
||||||
|
return zc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := zc.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return zc.jwtToken, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
jwtToken, err := zc.parseRequestJWTResponse(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return zc.jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
zc.jwtToken = jwtToken
|
||||||
|
|
||||||
|
return zc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in zitadel Idp and sends an invite.
|
||||||
|
func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||||
|
payload, err := buildZitadelCreateUserRequestPayload(email, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.post("users/human/_import", payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
}
|
||||||
|
err = zm.helper.Unmarshal(body, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
appMetadata := AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata to new user
|
||||||
|
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return zm.GetUserDataByID(result.UserID, appMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
|
searchByEmail := zitadelAttributes{
|
||||||
|
"queries": {
|
||||||
|
{
|
||||||
|
"emailQuery": map[string]any{
|
||||||
|
"emailAddress": email,
|
||||||
|
"method": "TEXT_QUERY_METHOD_EQUALS",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := zm.helper.Marshal(searchByEmail)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.post("users/_search", string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Result []zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, profile := range profiles.Result {
|
||||||
|
metadata, err := zm.getUserMetadata(profile.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.Metadata = metadata
|
||||||
|
|
||||||
|
users = append(users, profile.userData())
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from zitadel via ID.
|
||||||
|
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
|
body, err := zm.get("users/"+userID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile struct{ User zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata, err := zm.getUserMetadata(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.User.Metadata = metadata
|
||||||
|
|
||||||
|
return profile.User.userData(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given profile.
|
||||||
|
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
|
accounts, err := zm.GetAllAccounts()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts[accountID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// It returns a list of users indexed by accountID.
|
||||||
|
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
|
body, err := zm.post("users/_search", "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profiles struct{ Result []zitadelProfile }
|
||||||
|
err = zm.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
for _, profile := range profiles.Result {
|
||||||
|
// fetch user metadata
|
||||||
|
metadata, err := zm.getUserMetadata(profile.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
profile.Metadata = metadata
|
||||||
|
|
||||||
|
userData := profile.userData()
|
||||||
|
accountID := userData.AppMetadata.WTAccountID
|
||||||
|
|
||||||
|
if accountID != "" {
|
||||||
|
if _, ok := indexedUsers[accountID]; !ok {
|
||||||
|
indexedUsers[accountID] = make([]*UserData, 0)
|
||||||
|
}
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
|
// Metadata values are base64 encoded.
|
||||||
|
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
|
if appMetadata.WTPendingInvite == nil {
|
||||||
|
appMetadata.WTPendingInvite = new(bool)
|
||||||
|
}
|
||||||
|
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
|
||||||
|
|
||||||
|
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
|
||||||
|
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
|
||||||
|
|
||||||
|
metadata := zitadelAttributes{
|
||||||
|
"metadata": {
|
||||||
|
{
|
||||||
|
"key": wtAccountID,
|
||||||
|
"value": wtAccountIDValue,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": wtPendingInvite,
|
||||||
|
"value": wtPendingInviteValue,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := zm.helper.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
|
||||||
|
_, err = zm.post(resource, string(payload))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserMetadata requests user metadata from zitadel via ID.
|
||||||
|
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||||
|
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||||
|
body, err := zm.post(resource, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var metadata struct{ Result []zitadelMetadata }
|
||||||
|
err = zm.helper.Unmarshal(body, &metadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// post perform Post requests.
|
||||||
|
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||||
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||||
|
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := zm.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// get perform Get requests.
|
||||||
|
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
|
jwtToken, err := zm.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s?%s", zm.managementEndpoint, resource, q.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := zm.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if zm.appMetrics != nil {
|
||||||
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// value returns string represented by the base64 string value.
|
||||||
|
func (zm zitadelMetadata) value() string {
|
||||||
|
value, err := base64.StdEncoding.DecodeString(zm.Value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// userData construct user data from zitadel profile.
|
||||||
|
func (zp zitadelProfile) userData() *UserData {
|
||||||
|
var (
|
||||||
|
email string
|
||||||
|
name string
|
||||||
|
wtAccountIDValue string
|
||||||
|
wtPendingInviteValue bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, metadata := range zp.Metadata {
|
||||||
|
if metadata.Key == wtAccountID {
|
||||||
|
wtAccountIDValue = metadata.value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if metadata.Key == wtPendingInvite {
|
||||||
|
value, err := strconv.ParseBool(metadata.value())
|
||||||
|
if err == nil {
|
||||||
|
wtPendingInviteValue = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain the email for the human account and the login name,
|
||||||
|
// for the machine account.
|
||||||
|
if zp.Human != nil {
|
||||||
|
email = zp.Human.Email.Email
|
||||||
|
name = zp.Human.Profile.DisplayName
|
||||||
|
} else {
|
||||||
|
if len(zp.LoginNames) > 0 {
|
||||||
|
email = zp.LoginNames[0]
|
||||||
|
name = zp.LoginNames[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
Email: email,
|
||||||
|
Name: name,
|
||||||
|
ID: zp.ID,
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: wtAccountIDValue,
|
||||||
|
WTPendingInvite: &wtPendingInviteValue,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
|
||||||
|
var firstName, lastName string
|
||||||
|
|
||||||
|
words := strings.Fields(name)
|
||||||
|
if n := len(words); n > 0 {
|
||||||
|
firstName = strings.Join(words[:n-1], " ")
|
||||||
|
lastName = words[n-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &zitadelUser{
|
||||||
|
UserName: name,
|
||||||
|
Profile: zitadelUserInfo{
|
||||||
|
FirstName: strings.TrimSpace(firstName),
|
||||||
|
LastName: strings.TrimSpace(lastName),
|
||||||
|
DisplayName: name,
|
||||||
|
},
|
||||||
|
Email: zitadelEmail{
|
||||||
|
Email: email,
|
||||||
|
IsEmailVerified: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(str), nil
|
||||||
|
}
|
||||||
486
management/server/idp/zitadel_test.go
Normal file
486
management/server/idp/zitadel_test.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewZitadelManager(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
inputConfig ZitadelClientConfig
|
||||||
|
assertErrFunc require.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTestConfig := ZitadelClientConfig{
|
||||||
|
ClientID: "client_id",
|
||||||
|
ClientSecret: "client_secret",
|
||||||
|
GrantType: "client_credentials",
|
||||||
|
TokenEndpoint: "http://localhost/oauth/v2/token",
|
||||||
|
ManagementEndpoint: "http://localhost/management/v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase1 := test{
|
||||||
|
name: "Good Configuration",
|
||||||
|
inputConfig: defaultTestConfig,
|
||||||
|
assertErrFunc: require.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase2Config := defaultTestConfig
|
||||||
|
testCase2Config.ClientID = ""
|
||||||
|
|
||||||
|
testCase2 := test{
|
||||||
|
name: "Missing ClientID Configuration",
|
||||||
|
inputConfig: testCase2Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase3Config := defaultTestConfig
|
||||||
|
testCase3Config.ClientSecret = ""
|
||||||
|
|
||||||
|
testCase3 := test{
|
||||||
|
name: "Missing ClientSecret Configuration",
|
||||||
|
inputConfig: testCase3Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
_, err := NewZitadelManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockZitadelCredentials struct {
|
||||||
|
jwtToken JWTToken
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
return mc.jwtToken, mc.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelRequestJWTToken(t *testing.T) {
|
||||||
|
|
||||||
|
type requestJWTTokenTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||||
|
name: "Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||||
|
name: "Request Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputRespBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputRespBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := creds.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err, "unable to read the response body")
|
||||||
|
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||||
|
assert.NoError(t, err, "unable to parse the json input")
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelParseRequestJWTResponse(t *testing.T) {
|
||||||
|
type parseRequestJWTResponseTest struct {
|
||||||
|
name string
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedToken string
|
||||||
|
expectedExpiresIn int
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := 100
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Good JWT Body",
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
expectedExpiresIn: exp,
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "no error was expected",
|
||||||
|
}
|
||||||
|
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Bad json JWT Body",
|
||||||
|
inputRespBody: "",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: "",
|
||||||
|
expectedExpiresIn: 0,
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "json error was expected",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelJwtStillValid(t *testing.T) {
|
||||||
|
type jwtStillValidTest struct {
|
||||||
|
name string
|
||||||
|
inputTime time.Time
|
||||||
|
expectedResult bool
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||||
|
name: "JWT still valid",
|
||||||
|
inputTime: time.Now().Add(10 * time.Second),
|
||||||
|
expectedResult: true,
|
||||||
|
message: "should be true",
|
||||||
|
}
|
||||||
|
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||||
|
name: "JWT is invalid",
|
||||||
|
inputTime: time.Now(),
|
||||||
|
expectedResult: false,
|
||||||
|
message: "should be false",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelAuthenticate(t *testing.T) {
|
||||||
|
type authenticateTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputResBody string
|
||||||
|
inputExpireToken time.Time
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedCode int
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
authenticateTestCase1 := authenticateTest{
|
||||||
|
name: "Get Cached token",
|
||||||
|
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: nil,
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase2 := authenticateTest{
|
||||||
|
name: "Get Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase3 := authenticateTest{
|
||||||
|
name: "Get Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputResBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputResBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := ZitadelClientConfig{}
|
||||||
|
|
||||||
|
creds := ZitadelCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||||
|
|
||||||
|
_, err := creds.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
|
||||||
|
type updateUserAppMetadataTest struct {
|
||||||
|
name string
|
||||||
|
inputReqBody string
|
||||||
|
expectedReqBody string
|
||||||
|
appMetadata AppMetadata
|
||||||
|
statusCode int
|
||||||
|
helper ManagerHelper
|
||||||
|
managerCreds ManagerCredentials
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Authentication",
|
||||||
|
expectedReqBody: "",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
err: fmt.Errorf("error"),
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Response Parsing",
|
||||||
|
statusCode: 400,
|
||||||
|
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||||
|
name: "Good request",
|
||||||
|
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 200,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||||
|
name: "Update Pending Invite",
|
||||||
|
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
|
||||||
|
appMetadata: AppMetadata{
|
||||||
|
WTAccountID: "ok",
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
},
|
||||||
|
statusCode: 200,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockZitadelCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||||
|
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
reqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputReqBody,
|
||||||
|
code: testCase.statusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
httpClient: &reqClient,
|
||||||
|
credentials: testCase.managerCreds,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelProfile(t *testing.T) {
|
||||||
|
type azureProfileTest struct {
|
||||||
|
name string
|
||||||
|
invite bool
|
||||||
|
inputProfile zitadelProfile
|
||||||
|
expectedUserData UserData
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase1 := azureProfileTest{
|
||||||
|
name: "User Request",
|
||||||
|
invite: false,
|
||||||
|
inputProfile: zitadelProfile{
|
||||||
|
ID: "test1",
|
||||||
|
State: "USER_STATE_ACTIVE",
|
||||||
|
UserName: "test1@mail.com",
|
||||||
|
PreferredLoginName: "test1@mail.com",
|
||||||
|
LoginNames: []string{
|
||||||
|
"test1@mail.com",
|
||||||
|
},
|
||||||
|
Human: &zitadelUser{
|
||||||
|
Profile: zitadelUserInfo{
|
||||||
|
FirstName: "ZITADEL",
|
||||||
|
LastName: "Admin",
|
||||||
|
DisplayName: "ZITADEL Admin",
|
||||||
|
},
|
||||||
|
Email: zitadelEmail{
|
||||||
|
Email: "test1@mail.com",
|
||||||
|
IsEmailVerified: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: []zitadelMetadata{
|
||||||
|
{
|
||||||
|
Key: "wt_account_id",
|
||||||
|
Value: "MQ==",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: "wt_pending_invite",
|
||||||
|
Value: "ZmFsc2U=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test1",
|
||||||
|
Name: "ZITADEL Admin",
|
||||||
|
Email: "test1@mail.com",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
azureProfileTestCase2 := azureProfileTest{
|
||||||
|
name: "Service User Request",
|
||||||
|
invite: true,
|
||||||
|
inputProfile: zitadelProfile{
|
||||||
|
ID: "test2",
|
||||||
|
State: "USER_STATE_ACTIVE",
|
||||||
|
UserName: "machine",
|
||||||
|
PreferredLoginName: "machine",
|
||||||
|
LoginNames: []string{
|
||||||
|
"machine",
|
||||||
|
},
|
||||||
|
Human: nil,
|
||||||
|
Metadata: []zitadelMetadata{
|
||||||
|
{
|
||||||
|
Key: "wt_account_id",
|
||||||
|
Value: "MQ==",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: "wt_pending_invite",
|
||||||
|
Value: "dHJ1ZQ==",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUserData: UserData{
|
||||||
|
ID: "test2",
|
||||||
|
Name: "machine",
|
||||||
|
Email: "machine",
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||||
|
userData := testCase.inputProfile.userData()
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||||
|
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,7 +19,7 @@ type MockAccountManager struct {
|
|||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
AccountExistsFunc func(accountId string) (*bool, error)
|
AccountExistsFunc func(accountId string) (*bool, error)
|
||||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||||
@@ -60,11 +60,11 @@ type MockAccountManager struct {
|
|||||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||||
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
@@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||||
if am.CreatePATFunc != nil {
|
if am.CreatePATFunc != nil {
|
||||||
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
|
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if am.DeletePATFunc != nil {
|
if am.DeletePATFunc != nil {
|
||||||
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
|
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||||
if am.GetPATFunc != nil {
|
if am.GetPATFunc != nil {
|
||||||
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
|
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||||
if am.GetAllPATsFunc != nil {
|
if am.GetAllPATsFunc != nil {
|
||||||
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
|
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||||
}
|
}
|
||||||
@@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
|||||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||||
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||||
if am.IsUserAdminFunc != nil {
|
if am.GetUserFunc != nil {
|
||||||
return am.IsUserAdminFunc(claims)
|
return am.GetUserFunc(claims)
|
||||||
}
|
}
|
||||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||||
@@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
if am.DeleteUserFunc != nil {
|
if am.DeleteUserFunc != nil {
|
||||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
|
|||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if peerLoginExpired(peer, account) {
|
if peerLoginExpired(peer, account) {
|
||||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
@@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
|||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
updateRemotePeers := false
|
updateRemotePeers := false
|
||||||
if peerLoginExpired(peer, account) {
|
if peerLoginExpired(peer, account) {
|
||||||
err = checkAuth(login.UserID, peer)
|
err = checkAuth(login.UserID, peer)
|
||||||
@@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
|||||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
user, err := account.FindUser(peer.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||||
|
}
|
||||||
|
if user.IsBlocked() {
|
||||||
|
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func checkAuth(loginUserID string, peer *Peer) error {
|
func checkAuth(loginUserID string, peer *Peer) error {
|
||||||
if loginUserID == "" {
|
if loginUserID == "" {
|
||||||
// absence of a user ID indicates that JWT wasn't provided.
|
// absence of a user ID indicates that JWT wasn't provided.
|
||||||
|
|||||||
@@ -51,15 +51,22 @@ type User struct {
|
|||||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||||
AutoGroups []string
|
AutoGroups []string
|
||||||
PATs map[string]*PersonalAccessToken
|
PATs map[string]*PersonalAccessToken
|
||||||
|
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||||
|
Blocked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAdmin returns true if user is an admin, false otherwise
|
// IsBlocked returns true if the user is blocked, false otherwise
|
||||||
|
func (u *User) IsBlocked() bool {
|
||||||
|
return u.Blocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAdmin returns true if the user is an admin, false otherwise
|
||||||
func (u *User) IsAdmin() bool {
|
func (u *User) IsAdmin() bool {
|
||||||
return u.Role == UserRoleAdmin
|
return u.Role == UserRoleAdmin
|
||||||
}
|
}
|
||||||
|
|
||||||
// toUserInfo converts a User object to a UserInfo object.
|
// ToUserInfo converts a User object to a UserInfo object.
|
||||||
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||||
autoGroups := u.AutoGroups
|
autoGroups := u.AutoGroups
|
||||||
if autoGroups == nil {
|
if autoGroups == nil {
|
||||||
autoGroups = []string{}
|
autoGroups = []string{}
|
||||||
@@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
AutoGroups: u.AutoGroups,
|
AutoGroups: u.AutoGroups,
|
||||||
Status: string(UserStatusActive),
|
Status: string(UserStatusActive),
|
||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
if userData.ID != u.Id {
|
if userData.ID != u.Id {
|
||||||
@@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
AutoGroups: autoGroups,
|
AutoGroups: autoGroups,
|
||||||
Status: string(userStatus),
|
Status: string(userStatus),
|
||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,6 +122,7 @@ func (u *User) Copy() *User {
|
|||||||
IsServiceUser: u.IsServiceUser,
|
IsServiceUser: u.IsServiceUser,
|
||||||
ServiceUserName: u.ServiceUserName,
|
ServiceUserName: u.ServiceUserName,
|
||||||
PATs: pats,
|
PATs: pats,
|
||||||
|
Blocked: u.Blocked,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +148,7 @@ func NewAdminUser(id string) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
|||||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if executingUser == nil {
|
if executingUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
@@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||||
|
|
||||||
return &UserInfo{
|
return &UserInfo{
|
||||||
ID: newUser.Id,
|
ID: newUser.Id,
|
||||||
@@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||||
@@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users) > 0 {
|
if len(users) > 0 {
|
||||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||||
@@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
|||||||
|
|
||||||
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||||
|
|
||||||
return newUser.toUserInfo(idpUser)
|
return newUser.ToUserInfo(idpUser)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser looks up a user by provided authorization claims.
|
||||||
|
// It will also create an account if didn't exist for this user before.
|
||||||
|
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
|
account, _, err := am.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := account.Users[claims.UserId]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteUser deletes a user from the given account.
|
// DeleteUser deletes a user from the given account.
|
||||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if executingUser == nil {
|
if executingUser == nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
@@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||||
|
|
||||||
delete(account.Users, targetUserID)
|
delete(account.Users, targetUserID)
|
||||||
|
|
||||||
@@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreatePAT creates a new PAT for the given user
|
// CreatePAT creates a new PAT for the given user
|
||||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
|||||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||||
|
|
||||||
return pat, nil
|
return pat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePAT deletes a specific PAT from a user
|
// DeletePAT deletes a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||||
|
|
||||||
delete(targetUser.PATs, tokenID)
|
delete(targetUser.PATs, tokenID)
|
||||||
|
|
||||||
@@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPAT returns a specific PAT from a user
|
// GetPAT returns a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
|||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
|||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser := account.Users[executingUserID]
|
executingUser := account.Users[initiatorUserID]
|
||||||
if targetUser == nil {
|
if targetUser == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
|||||||
return pats, nil
|
return pats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
||||||
// Only User.AutoGroups field is allowed to be updated for now.
|
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||||
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
|
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
initiatorUser, err := account.FindUser(initiatorUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldUser := account.Users[update.Id]
|
||||||
|
if oldUser == nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||||
|
}
|
||||||
|
|
||||||
|
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||||
|
}
|
||||||
|
|
||||||
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
|
newUser := oldUser.Copy()
|
||||||
|
newUser.Role = update.Role
|
||||||
|
newUser.Blocked = update.Blocked
|
||||||
|
|
||||||
for _, newGroupID := range update.AutoGroups {
|
for _, newGroupID := range update.AutoGroups {
|
||||||
if _, ok := account.Groups[newGroupID]; !ok {
|
if _, ok := account.Groups[newGroupID]; !ok {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||||
newGroupID, update.Id)
|
newGroupID, update.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUser := account.Users[update.Id]
|
|
||||||
if oldUser == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "update not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// only auto groups, revoked status, and name can be updated for now
|
|
||||||
newUser := oldUser.Copy()
|
|
||||||
newUser.AutoGroups = update.AutoGroups
|
newUser.AutoGroups = update.AutoGroups
|
||||||
newUser.Role = update.Role
|
|
||||||
|
|
||||||
account.Users[newUser.Id] = newUser
|
account.Users[newUser.Id] = newUser
|
||||||
|
|
||||||
|
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||||
|
// expire peers that belong to the user who's getting blocked
|
||||||
|
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range blockedPeers {
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
peer.MarkLoginExpired(true)
|
||||||
|
account.UpdatePeer(peer)
|
||||||
|
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||||
|
err = am.updateAccountPeers(account)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
||||||
|
return nil, err
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(account); err != nil {
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// store activity logs
|
||||||
if oldUser.Role != newUser.Role {
|
if oldUser.Role != newUser.Role {
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||||
}
|
}
|
||||||
|
|
||||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
if update.AutoGroups != nil {
|
||||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||||
for _, g := range removedGroups {
|
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||||
group := account.GetGroup(g)
|
for _, g := range removedGroups {
|
||||||
if group != nil {
|
group := account.GetGroup(g)
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
if group != nil {
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||||
} else {
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
} else {
|
||||||
|
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
for _, g := range addedGroups {
|
||||||
|
group := account.GetGroup(g)
|
||||||
for _, g := range addedGroups {
|
if group != nil {
|
||||||
group := account.GetGroup(g)
|
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||||
if group != nil {
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
}
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
|
||||||
} else {
|
|
||||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||||
@@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
|||||||
if userData == nil {
|
if userData == nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||||
}
|
}
|
||||||
return newUser.toUserInfo(userData)
|
return newUser.ToUserInfo(userData)
|
||||||
}
|
}
|
||||||
return newUser.toUserInfo(nil)
|
return newUser.ToUserInfo(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||||
@@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
|
||||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
|
||||||
account, _, err := am.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get account: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user, ok := account.Users[claims.UserId]
|
|
||||||
if !ok {
|
|
||||||
return false, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return user.Role == UserRoleAdmin, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||||
// based on provided user role.
|
// based on provided user role.
|
||||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
||||||
@@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
// if user is not an admin then show only current user and do not show other users
|
// if user is not an admin then show only current user and do not show other users
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
info, err := accountUser.toUserInfo(nil)
|
info, err := accountUser.ToUserInfo(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
|
|
||||||
var info *UserInfo
|
var info *UserInfo
|
||||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||||
info, err = localUser.toUserInfo(queriedUser)
|
info, err = localUser.ToUserInfo(queriedUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
|
|||||||
LastUsed: time.Now(),
|
LastUsed: time.Now(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Blocked: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := validateStruct(user)
|
err := validateStruct(user)
|
||||||
@@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
|
|||||||
field := structVal.Field(i)
|
field := structVal.Field(i)
|
||||||
fieldName := structType.Field(i).Name
|
fieldName := structType.Field(i).Name
|
||||||
|
|
||||||
isSet := field.IsValid() && !field.IsZero()
|
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
|
||||||
|
|
||||||
if !isSet {
|
if !isSet {
|
||||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||||
@@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
|||||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
|
|
||||||
@@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
|||||||
UserId: mockUserID,
|
UserId: mockUserID,
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := am.IsUserAdmin(claims)
|
user, err := am.GetUser(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when checking user role: %s", err)
|
t.Fatalf("Error when checking user role: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.Equal(t, mockUserID, user.Id)
|
||||||
|
assert.True(t, user.IsAdmin())
|
||||||
|
assert.False(t, user.IsBlocked())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
func TestUser_IsAdmin(t *testing.T) {
|
||||||
store := newStore(t)
|
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
|
||||||
account.Users[mockUserID] = &User{
|
|
||||||
Id: mockUserID,
|
|
||||||
Role: "user",
|
|
||||||
}
|
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
user := NewAdminUser(mockUserID)
|
||||||
if err != nil {
|
assert.True(t, user.IsAdmin())
|
||||||
t.Fatalf("Error when saving account: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
am := DefaultAccountManager{
|
user = NewRegularUser(mockUserID)
|
||||||
Store: store,
|
assert.False(t, user.IsAdmin())
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := jwtclaims.AuthorizationClaims{
|
|
||||||
UserId: mockUserID,
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := am.IsUserAdmin(claims)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error when checking user role: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||||
@@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(users))
|
assert.Equal(t, 1, len(users))
|
||||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
regularUserID := "regularUser"
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
adminInitiator bool
|
||||||
|
update *User
|
||||||
|
expectedErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_Admin_Role",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleUser,
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "Should_Fail_When_Admin_Blocks_Themselves",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_Non_Existing_User",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
|
||||||
|
expectedErr: true,
|
||||||
|
adminInitiator: false,
|
||||||
|
update: &User{
|
||||||
|
Id: userID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should_Update_User",
|
||||||
|
expectedErr: false,
|
||||||
|
adminInitiator: true,
|
||||||
|
update: &User{
|
||||||
|
Id: regularUserID,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
|
||||||
|
// create an account and an admin user
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a regular user
|
||||||
|
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||||
|
err = manager.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
initiatorID := userID
|
||||||
|
if !tc.adminInitiator {
|
||||||
|
initiatorID = regularUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
|
||||||
|
if tc.expectedErr {
|
||||||
|
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "expecting SaveUser not to throw an error")
|
||||||
|
assert.NotNil(t, updated)
|
||||||
|
|
||||||
|
assert.Equal(t, string(tc.update.Role), updated.Role)
|
||||||
|
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ download_release_binary() {
|
|||||||
echo "Installing $1 from $DOWNLOAD_URL"
|
echo "Installing $1 from $DOWNLOAD_URL"
|
||||||
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
||||||
|
|
||||||
|
|
||||||
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
|
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
|
||||||
INSTALL_DIR="/Applications/NetBird UI.app"
|
INSTALL_DIR="/Applications/NetBird UI.app"
|
||||||
|
|
||||||
@@ -43,8 +44,9 @@ download_release_binary() {
|
|||||||
unzip -q -o "$BINARY_NAME"
|
unzip -q -o "$BINARY_NAME"
|
||||||
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
|
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
|
||||||
else
|
else
|
||||||
|
sudo mkdir -p "$INSTALL_DIR"
|
||||||
tar -xzvf "$BINARY_NAME"
|
tar -xzvf "$BINARY_NAME"
|
||||||
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR"
|
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,44 +2,44 @@ package sharedsock
|
|||||||
|
|
||||||
import "golang.org/x/net/bpf"
|
import "golang.org/x/net/bpf"
|
||||||
|
|
||||||
// STUNFilter implements BPFFilter by filtering on STUN packets.
|
// IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port.
|
||||||
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
|
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
|
||||||
type STUNFilter struct {
|
type IncomingSTUNFilter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSTUNFilter creates an instance of a STUNFilter
|
// NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter
|
||||||
func NewSTUNFilter() BPFFilter {
|
func NewIncomingSTUNFilter() BPFFilter {
|
||||||
return &STUNFilter{}
|
return &IncomingSTUNFilter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
|
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
|
||||||
func (sf STUNFilter) GetInstructions(port uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
|
func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
|
||||||
raw4, err = rawInstructions(22, 32, port)
|
raw4, err = rawInstructions(22, 32, dstPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
raw6, err = rawInstructions(2, 12, port)
|
raw6, err = rawInstructions(2, 12, dstPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return raw4, raw6, nil
|
return raw4, raw6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rawInstructions(portOff, cookieOff, port uint32) ([]bpf.RawInstruction, error) {
|
func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) {
|
||||||
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers
|
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers
|
||||||
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
|
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
|
||||||
instructions := []bpf.Instruction{
|
instructions := []bpf.Instruction{
|
||||||
// Load the source port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
|
// Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
|
||||||
bpf.LoadAbsolute{Off: portOff, Size: 2},
|
bpf.LoadAbsolute{Off: dstPortOff, Size: 2},
|
||||||
// Check if the source port is equal to the specified `port`. If not, skip the next 3 instructions.
|
// Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions.
|
||||||
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: port, SkipTrue: 3},
|
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3},
|
||||||
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
|
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
|
||||||
bpf.LoadAbsolute{Off: cookieOff, Size: 4},
|
bpf.LoadAbsolute{Off: cookieOff, Size: 4},
|
||||||
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
|
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
|
||||||
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
|
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
|
||||||
// If both the port and the magic cookie match, return a positive value (0xffffffff)
|
// If both the dstPort and the magic cookie match, return a positive value (0xffffffff)
|
||||||
bpf.RetConstant{Val: 0xffffffff},
|
bpf.RetConstant{Val: 0xffffffff},
|
||||||
// If either the port or the magic cookie doesn't match, return 0
|
// If either the dstPort or the magic cookie doesn't match, return 0
|
||||||
bpf.RetConstant{Val: 0},
|
bpf.RetConstant{Val: 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
package sharedsock
|
package sharedsock
|
||||||
|
|
||||||
// NewSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
|
// NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
|
||||||
func NewSTUNFilter() BPFFilter {
|
func NewIncomingSTUNFilter() BPFFilter {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ import (
|
|||||||
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
|
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
|
||||||
|
|
||||||
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
|
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
|
||||||
// by BPF instructions (e.g., STUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
|
// by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
|
||||||
|
// It is meant to be used when sharing a port with some other process.
|
||||||
type SharedSocket struct {
|
type SharedSocket struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
conn4 *socket.Conn
|
conn4 *socket.Conn
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
|
|||||||
|
|
||||||
// create raw socket on a port
|
// create raw socket on a port
|
||||||
testingPort := 51821
|
testingPort := 51821
|
||||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
|
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||||
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
||||||
@@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
|
|||||||
|
|
||||||
func TestShouldNotReadNonSTUNPackets(t *testing.T) {
|
func TestShouldNotReadNonSTUNPackets(t *testing.T) {
|
||||||
testingPort := 39439
|
testingPort := 39439
|
||||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
defer rawSock.Close()
|
defer rawSock.Close()
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) {
|
|||||||
defer udpListener.Close()
|
defer udpListener.Close()
|
||||||
|
|
||||||
testingPort := 39440
|
testingPort := 39440
|
||||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
defer rawSock.Close()
|
defer rawSock.Close()
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSharedSocket_Close(t *testing.T) {
|
func TestSharedSocket_Close(t *testing.T) {
|
||||||
rawSock, err := Listen(39440, NewSTUNFilter())
|
rawSock, err := Listen(39440, NewIncomingSTUNFilter())
|
||||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||||
|
|
||||||
errGrp := errgroup.Group{}
|
errGrp := errgroup.Group{}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux
|
//go:build !linux || android
|
||||||
|
|
||||||
package sharedsock
|
package sharedsock
|
||||||
|
|
||||||
|
|||||||
@@ -350,7 +350,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
|||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||||
|
|
||||||
decryptedMessage, err := c.decryptMessage(msg)
|
decryptedMessage, err := c.decryptMessage(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user