mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
12 Commits
v0.19.0
...
separate_p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b43d7e8ef | ||
|
|
dcc83c8741 | ||
|
|
d56669ec2e | ||
|
|
e3d2b6a408 | ||
|
|
9f758b2015 | ||
|
|
2c50d7af1e | ||
|
|
e4c28f64fa | ||
|
|
6f2c4078ef | ||
|
|
f4ec1699ca | ||
|
|
fea53b2f0f | ||
|
|
60e6d0890a | ||
|
|
cb12e2da21 |
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
@@ -45,12 +46,16 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
PreSharedKey: &preSharedKey,
|
||||
})
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
@@ -106,7 +111,7 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
@@ -185,7 +190,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||
@@ -199,11 +204,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete string) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
var codeMsg string
|
||||
if !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
err := open.Run(verificationURIComplete)
|
||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
" " + verificationURIComplete + " \n\n")
|
||||
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||
if err != nil {
|
||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||
}
|
||||
|
||||
@@ -78,14 +78,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
PreSharedKey: &preSharedKey,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
})
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
@@ -172,7 +176,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -210,7 +209,7 @@ func (e *Engine) Start() error {
|
||||
e.udpMux = udpMux
|
||||
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||
} else {
|
||||
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewSTUNFilter())
|
||||
rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -247,7 +246,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
@@ -758,9 +757,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
|
||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||
e.syncMsgMux.Lock()
|
||||
conf := conn.GetConf()
|
||||
conf.StunTurn = append(e.STUNs, e.TURNs...)
|
||||
conn.UpdateConf(conf)
|
||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
err := conn.Open()
|
||||
@@ -789,9 +786,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
|
||||
proxyConfig := proxy.Config{
|
||||
wgConfig := peer.WgConfig{
|
||||
RemoteKey: pubKey,
|
||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
||||
WgListenPort: e.config.WgPort,
|
||||
WgInterface: e.wgInterface,
|
||||
AllowedIps: allowedIPs,
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
@@ -808,7 +805,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
Timeout: timeout,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
ProxyConfig: proxyConfig,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
|
||||
@@ -148,6 +148,11 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||
}
|
||||
|
||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||
if deviceCode.VerificationURIComplete == "" {
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
@@ -28,8 +28,18 @@ const (
|
||||
|
||||
iceKeepAliveDefault = 4 * time.Second
|
||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface *iface.WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
type ConnConfig struct {
|
||||
|
||||
@@ -48,7 +58,7 @@ type ConnConfig struct {
|
||||
|
||||
Timeout time.Duration
|
||||
|
||||
ProxyConfig proxy.Config
|
||||
WgConfig WgConfig
|
||||
|
||||
UDPMux ice.UDPMux
|
||||
UDPMuxSrflx ice.UniversalUDPMux
|
||||
@@ -103,7 +113,7 @@ type Conn struct {
|
||||
|
||||
statusRecorder *Status
|
||||
|
||||
proxy proxy.Proxy
|
||||
proxy *WireGuardProxy
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
@@ -127,9 +137,14 @@ func (conn *Conn) GetConf() ConnConfig {
|
||||
return conn.config
|
||||
}
|
||||
|
||||
// UpdateConf updates the connection config
|
||||
func (conn *Conn) UpdateConf(conf ConnConfig) {
|
||||
conn.config = conf
|
||||
// WgConfig returns the WireGuard config
|
||||
func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// UpdateStunTurn update the turn and stun addresses
|
||||
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
||||
conn.config.StunTurn = turnStun
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -240,12 +255,12 @@ func readICEAgentConfigProperties() (time.Duration, time.Duration) {
|
||||
func (conn *Conn) Open() error {
|
||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.ConnStatus = conn.status
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: conn.status,
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||
@@ -300,10 +315,11 @@ func (conn *Conn) Open() error {
|
||||
defer conn.notifyDisconnected()
|
||||
conn.mu.Unlock()
|
||||
|
||||
peerState = State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState = State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||
@@ -334,19 +350,12 @@ func (conn *Conn) Open() error {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
||||
// direct Wireguard connection
|
||||
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
|
||||
} else {
|
||||
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
|
||||
}
|
||||
log.Infof("connected to peer %s, proxy: %v, remote address: %s", conn.config.Key, conn.proxy != nil, remoteAddr.String())
|
||||
|
||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||
select {
|
||||
@@ -363,54 +372,58 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
var pair *ice.CandidatePair
|
||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
p := conn.getProxy(pair, remoteWgPort)
|
||||
conn.proxy = p
|
||||
err = p.Start(remoteConn)
|
||||
var endpoint net.Addr
|
||||
if isRelayCandidate(pair.Local) {
|
||||
conn.proxy = NewWireGuardProxy(conn.config.WgConfig.WgListenPort, conn.config.WgConfig.RemoteKey, remoteConn)
|
||||
endpoint, err = conn.proxy.Start()
|
||||
if err != nil {
|
||||
conn.proxy = nil
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||
endpoint = remoteConn.RemoteAddr()
|
||||
}
|
||||
|
||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpoint, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
if conn.proxy != nil {
|
||||
_ = conn.proxy.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.status = StatusConnected
|
||||
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
Direct: conn.proxy == nil,
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
||||
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo rename this method and the proxy package to something more appropriate
|
||||
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||
if isRelayCandidate(pair.Local) {
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -439,20 +452,22 @@ func (conn *Conn) cleanup() error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
var err1, err2, err3 error
|
||||
if conn.agent != nil {
|
||||
err := conn.agent.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
err1 = conn.agent.Close()
|
||||
if err1 == nil {
|
||||
conn.agent = nil
|
||||
}
|
||||
conn.agent = nil
|
||||
}
|
||||
|
||||
// todo: is it problem if we try to remove a peer what is never existed?
|
||||
err2 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
|
||||
if conn.proxy != nil {
|
||||
err := conn.proxy.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
err3 = conn.proxy.Close()
|
||||
if err3 != nil {
|
||||
conn.proxy = nil
|
||||
}
|
||||
conn.proxy = nil
|
||||
}
|
||||
|
||||
if conn.notifyDisconnected != nil {
|
||||
@@ -462,10 +477,11 @@ func (conn *Conn) cleanup() error {
|
||||
|
||||
conn.status = StatusDisconnected
|
||||
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||
@@ -474,8 +490,13 @@ func (conn *Conn) cleanup() error {
|
||||
}
|
||||
|
||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||
|
||||
return nil
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
return err3
|
||||
}
|
||||
|
||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package proxy
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
@@ -11,67 +12,45 @@ type WireGuardProxy struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
config Config
|
||||
wgListenPort int
|
||||
remoteKey string
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
||||
p := &WireGuardProxy{config: config}
|
||||
func NewWireGuardProxy(wgListenPort int, remoteKey string, remoteConn net.Conn) *WireGuardProxy {
|
||||
p := &WireGuardProxy{
|
||||
wgListenPort: wgListenPort,
|
||||
remoteKey: remoteKey,
|
||||
remoteConn: remoteConn,
|
||||
}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) updateEndpoint() error {
|
||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// add local proxy connection as a Wireguard peer
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
udpAddr, p.config.PreSharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
|
||||
func (p *WireGuardProxy) Start() (net.Addr, error) {
|
||||
lConn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", p.wgListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.updateEndpoint()
|
||||
if err != nil {
|
||||
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
p.localConn = lConn
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return nil
|
||||
return lConn.LocalAddr(), nil
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Close() error {
|
||||
p.cancel()
|
||||
if c := p.localConn; c != nil {
|
||||
if p.localConn != nil {
|
||||
err := p.localConn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,7 +62,7 @@ func (p *WireGuardProxy) proxyToRemote() {
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
|
||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.remoteKey)
|
||||
return
|
||||
default:
|
||||
n, err := p.localConn.Read(buf)
|
||||
@@ -107,7 +86,7 @@ func (p *WireGuardProxy) proxyToLocal() {
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
|
||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
|
||||
return
|
||||
default:
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
@@ -122,7 +101,3 @@ func (p *WireGuardProxy) proxyToLocal() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WireGuardProxy) Type() Type {
|
||||
return TypeWireGuard
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DummyProxy just sends pings to the RemoteKey peer and reads responses
|
||||
type DummyProxy struct {
|
||||
conn net.Conn
|
||||
remote string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDummyProxy(remote string) *DummyProxy {
|
||||
p := &DummyProxy{remote: remote}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Close() error {
|
||||
p.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Start(remoteConn net.Conn) error {
|
||||
p.conn = remoteConn
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
|
||||
return
|
||||
}
|
||||
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_, err := p.conn.Write([]byte("hello"))
|
||||
//log.Debugf("sent ping to %s", p.remote)
|
||||
if err != nil {
|
||||
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DummyProxy) Type() Type {
|
||||
return TypeDummy
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
||||
type NoProxy struct {
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewNoProxy creates a new NoProxy with a provided config
|
||||
func NewNoProxy(config Config) *NoProxy {
|
||||
return &NoProxy{config: config}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *NoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates WireGuard peer with the remote address
|
||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
}
|
||||
|
||||
func (p *NoProxy) Type() Type {
|
||||
return TypeNoProxy
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultWgKeepAlive = 25 * time.Second
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
||||
TypeWireGuard Type = "WireGuard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WgListenAddr string
|
||||
RemoteKey string
|
||||
WgInterface *iface.WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
type Proxy interface {
|
||||
io.Closer
|
||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
||||
Start(remoteConn net.Conn) error
|
||||
Type() Type
|
||||
}
|
||||
@@ -77,12 +77,17 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
|
||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||
// Endpoint is optional
|
||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint net.Addr, preSharedKey *wgtypes.Key) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
rAddr, err := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("updating interface %s peer %s, endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
|
||||
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, rAddr, preSharedKey)
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
|
||||
@@ -396,6 +396,12 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
}
|
||||
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
|
||||
log.Infof("configuring IdpManagerConfig.OIDCConfig.Issuer with a new value %s,", oidcConfig.Issuer)
|
||||
config.IdpManagerConfig.OIDCConfig.Issuer = strings.TrimRight(oidcConfig.Issuer, "/")
|
||||
|
||||
log.Infof("configuring IdpManagerConfig.OIDCConfig.TokenEndpoint with a new value %s,", oidcConfig.TokenEndpoint)
|
||||
config.IdpManagerConfig.OIDCConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
|
||||
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
oidcConfig.Issuer, config.HttpConfig.AuthIssuer)
|
||||
config.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||
@@ -441,7 +447,7 @@ type OIDCConfigResponse struct {
|
||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
res, err := http.Get(oidcEndpoint)
|
||||
if err != nil {
|
||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
|
||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
|
||||
@@ -49,16 +49,16 @@ type AccountManager interface {
|
||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
||||
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
||||
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||
@@ -69,10 +69,10 @@ type AccountManager interface {
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
GetPeerNetwork(peerID string) (*Network, error)
|
||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
@@ -179,6 +179,7 @@ type UserInfo struct {
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
}
|
||||
|
||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
@@ -902,7 +903,9 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
for _, user := range account.Users {
|
||||
users[user.Id] = struct{}{}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(users, account.Id)
|
||||
|
||||
@@ -91,6 +91,10 @@ const (
|
||||
ServiceUserCreated
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
// UserBlocked indicates that a user blocked another user
|
||||
UserBlocked
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -184,6 +188,10 @@ const (
|
||||
ServiceUserCreatedMessage string = "Service user created"
|
||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||
ServiceUserDeletedMessage string = "Service user deleted"
|
||||
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
|
||||
UserBlockedMessage string = "User blocked"
|
||||
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||
UserUnblockedMessage string = "User unblocked"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@@ -282,6 +290,10 @@ func (a Activity) Message() string {
|
||||
return ServiceUserCreatedMessage
|
||||
case ServiceUserDeleted:
|
||||
return ServiceUserDeletedMessage
|
||||
case UserBlocked:
|
||||
return UserBlockedMessage
|
||||
case UserUnblocked:
|
||||
return UserUnblockedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
|
||||
return "user.join"
|
||||
case UserInvited:
|
||||
return "user.invite"
|
||||
case UserBlocked:
|
||||
return "user.block"
|
||||
case UserUnblocked:
|
||||
return "user.unblock"
|
||||
case AccountCreated:
|
||||
return "account.create"
|
||||
case RuleAdded:
|
||||
|
||||
@@ -65,7 +65,7 @@ components:
|
||||
status:
|
||||
description: User's status
|
||||
type: string
|
||||
enum: [ "active","invited","disabled" ]
|
||||
enum: [ "active","invited","blocked" ]
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
@@ -79,6 +79,9 @@ components:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
is_blocked:
|
||||
description: Is true if this user is blocked. Blocked users can't use the system
|
||||
type: boolean
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
@@ -86,6 +89,7 @@ components:
|
||||
- role
|
||||
- auto_groups
|
||||
- status
|
||||
- is_blocked
|
||||
UserRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -97,9 +101,13 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
is_blocked:
|
||||
description: If set to true then user is blocked and can't use the system
|
||||
type: boolean
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
- is_blocked
|
||||
UserCreateRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -645,7 +653,7 @@ components:
|
||||
description: The string code of the activity that occurred during the event
|
||||
type: string
|
||||
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
||||
"user.role.update",
|
||||
"user.role.update", "user.block", "user.unblock",
|
||||
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
||||
"setupkey.group.delete", "setupkey.group.add",
|
||||
"rule.add", "rule.delete", "rule.update",
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
||||
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
||||
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
||||
EventActivityCodeUserBlock EventActivityCode = "user.block"
|
||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||
@@ -53,6 +54,7 @@ const (
|
||||
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
||||
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
|
||||
)
|
||||
|
||||
// Defines values for NameserverNsType.
|
||||
@@ -68,9 +70,9 @@ const (
|
||||
|
||||
// Defines values for UserStatus.
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusBlocked UserStatus = "blocked"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
)
|
||||
|
||||
// Account defines model for Account.
|
||||
@@ -552,6 +554,9 @@ type User struct {
|
||||
// Id User ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
|
||||
// IsCurrent Is true if authenticated user is the same as this user
|
||||
IsCurrent *bool `json:"is_current,omitempty"`
|
||||
|
||||
@@ -594,6 +599,9 @@ type UserRequest struct {
|
||||
// AutoGroups Groups to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// IsBlocked If set to true then user is blocked and can't use the system
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
accountManager.IsUserAdmin)
|
||||
accountManager.GetUser)
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
@@ -6,28 +6,30 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
isUserAdmin IsUserAdminFunc
|
||||
claimsExtract jwtclaims.ClaimsExtractor
|
||||
getUser GetUser
|
||||
}
|
||||
|
||||
// NewAccessControl instance constructor
|
||||
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
||||
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||
return &AccessControl{
|
||||
isUserAdmin: isUserAdmin,
|
||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
),
|
||||
getUser: getUser,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
ok, err := a.isUserAdmin(claims)
|
||||
user, err := a.getUser(claims)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
|
||||
if user.IsBlocked() {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Debugf("Regex failed")
|
||||
log.Debugf("regex failed")
|
||||
util.WriteError(status.Errorf(status.Internal, ""), w)
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
log.Debugf("Valid Path")
|
||||
log.Debugf("valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ var testAccount = &server.Account{
|
||||
func initPATTestData() *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
},
|
||||
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
|
||||
@@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
@@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
Id: userID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Blocked: req.IsBlocked,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
@@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
case "invited":
|
||||
userStatus = api.UserStatusInvited
|
||||
default:
|
||||
userStatus = api.UserStatusDisabled
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
if user.IsBlocked {
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
@@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
||||
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
@@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
|
||||
info, err := update.Copy().ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
@@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatusCode int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedUserID string
|
||||
expectedRole string
|
||||
expectedStatus string
|
||||
expectedBlocked bool
|
||||
expectedIsServiceUser bool
|
||||
expectedGroups []string
|
||||
}{
|
||||
{
|
||||
name: "Update_Block_User",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: true,
|
||||
expectedRole: "user",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Change_Role_To_Admin",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Groups",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_Because_AutoGroups_Is_Absent",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatusCode {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
if tc.expectedStatusCode == 200 {
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
respBody := &api.User{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("response content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedUserID, respBody.Id)
|
||||
assert.Equal(t, tc.expectedRole, respBody.Role)
|
||||
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
|
||||
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
|
||||
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
|
||||
|
||||
for _, expectedGroup := range tc.expectedGroups {
|
||||
exists := false
|
||||
for _, actualGroup := range respBody.AutoGroups {
|
||||
if expectedGroup == actualGroup {
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
name := "name"
|
||||
email := "email"
|
||||
|
||||
@@ -32,10 +32,10 @@ type Auth0Manager struct {
|
||||
// Auth0ClientConfig auth0 manager client configurations
|
||||
type Auth0ClientConfig struct {
|
||||
Audience string
|
||||
AuthIssuer string
|
||||
AuthIssuer string `json:"-"`
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GrantType string
|
||||
GrantType string `json:"-"`
|
||||
}
|
||||
|
||||
// auth0JWTRequest payload struct to request a JWT Token
|
||||
@@ -110,7 +110,8 @@ type auth0Profile struct {
|
||||
}
|
||||
|
||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||
func NewAuth0Manager(oidcConfig OIDCConfig, config Auth0ClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
@@ -121,17 +122,19 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.AuthIssuer = oidcConfig.TokenEndpoint
|
||||
config.GrantType = "client_credentials"
|
||||
|
||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.Audience == "" || config.AuthIssuer == "" {
|
||||
return nil, fmt.Errorf("auth0 idp configuration is not complete")
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType != "client_credentials" {
|
||||
return nil, fmt.Errorf("auth0 idp configuration failed. Grant Type should be client_credentials")
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(strings.ToLower(config.AuthIssuer), "https://") {
|
||||
return nil, fmt.Errorf("auth0 idp configuration failed. AuthIssuer should contain https://")
|
||||
if config.Audience == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||
}
|
||||
|
||||
credentials := &Auth0Credentials{
|
||||
|
||||
@@ -459,26 +459,9 @@ func TestNewAuth0Manager(t *testing.T) {
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||
|
||||
testCase3 := test{
|
||||
name: "Wrong Auth Issuer Format",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when wrong auth issuer format",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.GrantType = "spa"
|
||||
|
||||
testCase4 := test{
|
||||
name: "Wrong Grant Type",
|
||||
inputConfig: testCase4Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when wrong grant type",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
_, err := NewAuth0Manager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -37,12 +37,13 @@ type AzureManager struct {
|
||||
|
||||
// AzureClientConfig azure manager client configurations.
|
||||
type AzureClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GraphAPIEndpoint string
|
||||
ObjectID string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ObjectID string
|
||||
|
||||
GraphAPIEndpoint string `json:"-"`
|
||||
TokenEndpoint string `json:"-"`
|
||||
GrantType string `json:"-"`
|
||||
}
|
||||
|
||||
// AzureCredentials azure authentication information.
|
||||
@@ -74,7 +75,8 @@ type azureExtension struct {
|
||||
}
|
||||
|
||||
// NewAzureManager creates a new instance of the AzureManager.
|
||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
func NewAzureManager(oidcConfig OIDCConfig, config AzureClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -84,13 +86,20 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
config.GraphAPIEndpoint = "https://graph.microsoft.com"
|
||||
config.GrantType = "client_credentials"
|
||||
|
||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.GraphAPIEndpoint == "" || config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure idp configuration is not complete")
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType != "client_credentials" {
|
||||
return nil, fmt.Errorf("azure idp configuration failed. Grant Type should be client_credentials")
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.ObjectID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||
}
|
||||
|
||||
credentials := &AzureCredentials{
|
||||
|
||||
@@ -19,12 +19,21 @@ type Manager interface {
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
}
|
||||
|
||||
// OIDCConfig specifies configuration for OpenID Connect provider
|
||||
// These configurations are automatically loaded from the OIDC endpoint
|
||||
type OIDCConfig struct {
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
}
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
type Config struct {
|
||||
ManagerType string
|
||||
OIDCConfig OIDCConfig `json:"-"`
|
||||
Auth0ClientCredentials Auth0ClientConfig
|
||||
KeycloakClientCredentials KeycloakClientConfig
|
||||
AzureClientCredentials AzureClientConfig
|
||||
KeycloakClientCredentials KeycloakClientConfig
|
||||
ZitadelClientCredentials ZitadelClientConfig
|
||||
}
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
@@ -73,11 +82,13 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
||||
case "none", "":
|
||||
return nil, nil
|
||||
case "auth0":
|
||||
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
||||
return NewAuth0Manager(config.OIDCConfig, config.Auth0ClientCredentials, appMetrics)
|
||||
case "azure":
|
||||
return NewAzureManager(config.AzureClientCredentials, appMetrics)
|
||||
return NewAzureManager(config.OIDCConfig, config.AzureClientCredentials, appMetrics)
|
||||
case "keycloak":
|
||||
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
|
||||
return NewKeycloakManager(config.OIDCConfig, config.KeycloakClientCredentials, appMetrics)
|
||||
case "zitadel":
|
||||
return NewZitadelManager(config.OIDCConfig, config.ZitadelClientCredentials, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
}
|
||||
|
||||
@@ -37,8 +37,8 @@ type KeycloakClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AdminEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
TokenEndpoint string `json:"-"`
|
||||
GrantType string `json:"-"`
|
||||
}
|
||||
|
||||
// KeycloakCredentials keycloak authentication information.
|
||||
@@ -82,7 +82,8 @@ type keycloakProfile struct {
|
||||
}
|
||||
|
||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
func NewKeycloakManager(oidcConfig OIDCConfig, config KeycloakClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -92,13 +93,19 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
config.GrantType = "client_credentials"
|
||||
|
||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak idp configuration is not complete")
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType != "client_credentials" {
|
||||
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.AdminEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||
}
|
||||
|
||||
credentials := &KeycloakCredentials{
|
||||
|
||||
@@ -46,19 +46,19 @@ func TestNewKeycloakManager(t *testing.T) {
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = "authorization_code"
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.ClientSecret = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Wrong GrantType",
|
||||
inputConfig: testCase5Config,
|
||||
testCase3 := test{
|
||||
name: "Missing ClientSecret Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when wrong grant type",
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase5} {
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
_, err := NewKeycloakManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
|
||||
609
management/server/idp/zitadel.go
Normal file
609
management/server/idp/zitadel.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ZitadelManager zitadel manager client instance.
|
||||
type ZitadelManager struct {
|
||||
managementEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// ZitadelClientConfig zitadel manager client configurations.
|
||||
type ZitadelClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GrantType string `json:"-"`
|
||||
TokenEndpoint string `json:"-"`
|
||||
ManagementEndpoint string `json:"-"`
|
||||
}
|
||||
|
||||
// ZitadelCredentials zitadel authentication information.
|
||||
type ZitadelCredentials struct {
|
||||
clientConfig ZitadelClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// zitadelEmail specifies details of a user email.
|
||||
type zitadelEmail struct {
|
||||
Email string `json:"email"`
|
||||
IsEmailVerified bool `json:"isEmailVerified"`
|
||||
}
|
||||
|
||||
// zitadelUserInfo specifies user information.
|
||||
type zitadelUserInfo struct {
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
// zitadelUser specifies profile details for user account.
|
||||
type zitadelUser struct {
|
||||
UserName string `json:"userName,omitempty"`
|
||||
Profile zitadelUserInfo `json:"profile"`
|
||||
Email zitadelEmail `json:"email"`
|
||||
}
|
||||
|
||||
type zitadelAttributes map[string][]map[string]any
|
||||
|
||||
// zitadelMetadata holds additional user data.
|
||||
type zitadelMetadata struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// zitadelProfile represents an zitadel user profile response.
|
||||
type zitadelProfile struct {
|
||||
ID string `json:"id"`
|
||||
State string `json:"state"`
|
||||
UserName string `json:"userName"`
|
||||
PreferredLoginName string `json:"preferredLoginName"`
|
||||
LoginNames []string `json:"loginNames"`
|
||||
Human *zitadelUser `json:"human"`
|
||||
Metadata []zitadelMetadata
|
||||
}
|
||||
|
||||
// NewZitadelManager creates a new instance of the ZitadelManager.
|
||||
func NewZitadelManager(oidcConfig OIDCConfig, config ZitadelClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
config.ManagementEndpoint = fmt.Sprintf("%s/management/v1", oidcConfig.Issuer)
|
||||
config.GrantType = "client_credentials"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("zitadel IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
credentials := &ZitadelCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &ZitadelManager{
|
||||
managementEndpoint: config.ManagementEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from zitadel.
|
||||
func (zc *ZitadelCredentials) jwtStillValid() bool {
|
||||
return !zc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(zc.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", zc.clientConfig.ClientID)
|
||||
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", zc.clientConfig.GrantType)
|
||||
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for zitadel idp manager")
|
||||
|
||||
resp, err := zc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if zc.appMetrics != nil {
|
||||
zc.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds.
|
||||
func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = zc.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = zc.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Zitadel Management API.
|
||||
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
zc.mux.Lock()
|
||||
defer zc.mux.Unlock()
|
||||
|
||||
if zc.appMetrics != nil {
|
||||
zc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if zc.jwtStillValid() {
|
||||
return zc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := zc.requestJWTToken()
|
||||
if err != nil {
|
||||
return zc.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := zc.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return zc.jwtToken, err
|
||||
}
|
||||
|
||||
zc.jwtToken = jwtToken
|
||||
|
||||
return zc.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite.
|
||||
func (zm *ZitadelManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||
payload, err := buildZitadelCreateUserRequestPayload(email, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/human/_import", payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
var result struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
err = zm.helper.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invite := true
|
||||
appMetadata := AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
}
|
||||
|
||||
// Add metadata to new user
|
||||
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return zm.GetUserDataByID(result.UserID, appMetadata)
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
searchByEmail := zitadelAttributes{
|
||||
"queries": {
|
||||
{
|
||||
"emailQuery": map[string]any{
|
||||
"emailAddress": email,
|
||||
"method": "TEXT_QUERY_METHOD_EQUALS",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, err := zm.helper.Marshal(searchByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/_search", string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
metadata, err := zm.getUserMetadata(profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.Metadata = metadata
|
||||
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := zm.get("users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile struct{ User zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metadata, err := zm.getUserMetadata(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.User.Metadata = metadata
|
||||
|
||||
return profile.User.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
accounts, err := zm.GetAllAccounts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
return accounts[accountID], nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
body, err := zm.post("users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles.Result {
|
||||
// fetch user metadata
|
||||
metadata, err := zm.getUserMetadata(profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.Metadata = metadata
|
||||
|
||||
userData := profile.userData()
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
if appMetadata.WTPendingInvite == nil {
|
||||
appMetadata.WTPendingInvite = new(bool)
|
||||
}
|
||||
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
|
||||
|
||||
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
|
||||
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
|
||||
|
||||
metadata := zitadelAttributes{
|
||||
"metadata": {
|
||||
{
|
||||
"key": wtAccountID,
|
||||
"value": wtAccountIDValue,
|
||||
},
|
||||
{
|
||||
"key": wtPendingInvite,
|
||||
"value": wtPendingInviteValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, err := zm.helper.Marshal(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
|
||||
_, err = zm.post(resource, string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUserMetadata requests user metadata from zitadel via ID.
|
||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||
body, err := zm.post(resource, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata struct{ Result []zitadelMetadata }
|
||||
err = zm.helper.Unmarshal(body, &metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return metadata.Result, nil
|
||||
}
|
||||
|
||||
// post perform Post requests.
|
||||
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := zm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s?%s", zm.managementEndpoint, resource, q.Encode())
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := zm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// value returns string represented by the base64 string value.
|
||||
func (zm zitadelMetadata) value() string {
|
||||
value, err := base64.StdEncoding.DecodeString(zm.Value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(value)
|
||||
}
|
||||
|
||||
// userData construct user data from zitadel profile.
|
||||
func (zp zitadelProfile) userData() *UserData {
|
||||
var (
|
||||
email string
|
||||
name string
|
||||
wtAccountIDValue string
|
||||
wtPendingInviteValue bool
|
||||
)
|
||||
|
||||
for _, metadata := range zp.Metadata {
|
||||
if metadata.Key == wtAccountID {
|
||||
wtAccountIDValue = metadata.value()
|
||||
}
|
||||
|
||||
if metadata.Key == wtPendingInvite {
|
||||
value, err := strconv.ParseBool(metadata.value())
|
||||
if err == nil {
|
||||
wtPendingInviteValue = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Obtain the email for the human account and the login name,
|
||||
// for the machine account.
|
||||
if zp.Human != nil {
|
||||
email = zp.Human.Email.Email
|
||||
name = zp.Human.Profile.DisplayName
|
||||
} else {
|
||||
if len(zp.LoginNames) > 0 {
|
||||
email = zp.LoginNames[0]
|
||||
name = zp.LoginNames[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: zp.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: wtAccountIDValue,
|
||||
WTPendingInvite: &wtPendingInviteValue,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
|
||||
var firstName, lastName string
|
||||
|
||||
words := strings.Fields(name)
|
||||
if n := len(words); n > 0 {
|
||||
firstName = strings.Join(words[:n-1], " ")
|
||||
lastName = words[n-1]
|
||||
}
|
||||
|
||||
req := &zitadelUser{
|
||||
UserName: name,
|
||||
Profile: zitadelUserInfo{
|
||||
FirstName: strings.TrimSpace(firstName),
|
||||
LastName: strings.TrimSpace(lastName),
|
||||
DisplayName: name,
|
||||
},
|
||||
Email: zitadelEmail{
|
||||
Email: email,
|
||||
IsEmailVerified: false,
|
||||
},
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
486
management/server/idp/zitadel_test.go
Normal file
486
management/server/idp/zitadel_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewZitadelManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig ZitadelClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := ZitadelClientConfig{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
GrantType: "client_credentials",
|
||||
TokenEndpoint: "http://localhost/oauth/v2/token",
|
||||
ManagementEndpoint: "http://localhost/management/v1",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.ClientSecret = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing ClientSecret Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewZitadelManager(OIDCConfig{}, testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockZitadelCredentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func TestZitadelRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := ZitadelClientConfig{}
|
||||
|
||||
creds := ZitadelCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := ZitadelClientConfig{}
|
||||
|
||||
creds := ZitadelCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := ZitadelClientConfig{}
|
||||
|
||||
creds := ZitadelCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := ZitadelClientConfig{}
|
||||
|
||||
creds := ZitadelCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
reqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
httpClient: &reqClient,
|
||||
credentials: testCase.managerCreds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile zitadelProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "User Request",
|
||||
invite: false,
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test1",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
UserName: "test1@mail.com",
|
||||
PreferredLoginName: "test1@mail.com",
|
||||
LoginNames: []string{
|
||||
"test1@mail.com",
|
||||
},
|
||||
Human: &zitadelUser{
|
||||
Profile: zitadelUserInfo{
|
||||
FirstName: "ZITADEL",
|
||||
LastName: "Admin",
|
||||
DisplayName: "ZITADEL Admin",
|
||||
},
|
||||
Email: zitadelEmail{
|
||||
Email: "test1@mail.com",
|
||||
IsEmailVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []zitadelMetadata{
|
||||
{
|
||||
Key: "wt_account_id",
|
||||
Value: "MQ==",
|
||||
},
|
||||
{
|
||||
Key: "wt_pending_invite",
|
||||
Value: "ZmFsc2U=",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test1",
|
||||
Name: "ZITADEL Admin",
|
||||
Email: "test1@mail.com",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Service User Request",
|
||||
invite: true,
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test2",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
UserName: "machine",
|
||||
PreferredLoginName: "machine",
|
||||
LoginNames: []string{
|
||||
"machine",
|
||||
},
|
||||
Human: nil,
|
||||
Metadata: []zitadelMetadata{
|
||||
{
|
||||
Key: "wt_account_id",
|
||||
Value: "MQ==",
|
||||
},
|
||||
{
|
||||
Key: "wt_pending_invite",
|
||||
Value: "dHJ1ZQ==",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test2",
|
||||
Name: "machine",
|
||||
Email: "machine",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,7 @@ type MockAccountManager struct {
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||
@@ -60,11 +60,11 @@ type MockAccountManager struct {
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
@@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
|
||||
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||
}
|
||||
|
||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if am.DeletePATFunc != nil {
|
||||
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||
}
|
||||
|
||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if am.GetPATFunc != nil {
|
||||
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||
}
|
||||
|
||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if am.GetAllPATsFunc != nil {
|
||||
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
|
||||
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||
}
|
||||
@@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
||||
}
|
||||
|
||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
||||
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
if am.IsUserAdminFunc != nil {
|
||||
return am.IsUserAdminFunc(claims)
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
return am.GetUserFunc(claims)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
@@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
||||
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
@@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(peer, account) {
|
||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
@@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(peer, account) {
|
||||
err = checkAuth(login.UserID, peer)
|
||||
@@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||
}
|
||||
if user.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAuth(loginUserID string, peer *Peer) error {
|
||||
if loginUserID == "" {
|
||||
// absence of a user ID indicates that JWT wasn't provided.
|
||||
|
||||
@@ -51,15 +51,22 @@ type User struct {
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
PATs map[string]*PersonalAccessToken
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
}
|
||||
|
||||
// IsAdmin returns true if user is an admin, false otherwise
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
func (u *User) IsBlocked() bool {
|
||||
return u.Blocked
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the user is an admin, false otherwise
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == UserRoleAdmin
|
||||
}
|
||||
|
||||
// toUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
@@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -113,6 +122,7 @@ func (u *User) Copy() *User {
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,7 +148,7 @@ func NewAdminUser(id string) *User {
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
|
||||
return &UserInfo{
|
||||
ID: newUser.Id,
|
||||
@@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||
@@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||
@@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
|
||||
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||
|
||||
return newUser.toUserInfo(idpUser)
|
||||
return newUser.ToUserInfo(idpUser)
|
||||
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the given account.
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
|
||||
@@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
@@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
@@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(targetUser.PATs, tokenID)
|
||||
|
||||
@@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||
}
|
||||
|
||||
@@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
@@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups field is allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
|
||||
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initiatorUser, err := account.FindUser(initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
||||
}
|
||||
|
||||
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
|
||||
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.Role = update.Role
|
||||
newUser.Blocked = update.Blocked
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "update not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.AutoGroups = update.AutoGroups
|
||||
newUser.Role = update.Role
|
||||
|
||||
account.Users[newUser.Id] = newUser
|
||||
|
||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||
// expire peers that belong to the user who's getting blocked
|
||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var peerIDs []string
|
||||
for _, peer := range blockedPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||
if err != nil {
|
||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
||||
return nil, err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// store activity logs
|
||||
if oldUser.Role != newUser.Role {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
}
|
||||
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
if update.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||
@@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
if userData == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||
}
|
||||
return newUser.toUserInfo(userData)
|
||||
return newUser.ToUserInfo(userData)
|
||||
}
|
||||
return newUser.toUserInfo(nil)
|
||||
return newUser.ToUserInfo(nil)
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
@@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return false, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
return user.Role == UserRoleAdmin, nil
|
||||
}
|
||||
|
||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||
// based on provided user role.
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
||||
@@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
}
|
||||
info, err := accountUser.toUserInfo(nil)
|
||||
info, err := accountUser.ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
|
||||
var info *UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.toUserInfo(queriedUser)
|
||||
info, err = localUser.ToUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
|
||||
LastUsed: time.Now(),
|
||||
},
|
||||
},
|
||||
Blocked: false,
|
||||
}
|
||||
|
||||
err := validateStruct(user)
|
||||
@@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
|
||||
field := structVal.Field(i)
|
||||
fieldName := structType.Field(i).Name
|
||||
|
||||
isSet := field.IsValid() && !field.IsZero()
|
||||
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
|
||||
|
||||
if !isSet {
|
||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||
@@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
@@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
UserId: mockUserID,
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(claims)
|
||||
user, err := am.GetUser(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, mockUserID, user.Id)
|
||||
assert.True(t, user.IsAdmin())
|
||||
assert.False(t, user.IsBlocked())
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
Role: "user",
|
||||
}
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
user := NewAdminUser(mockUserID)
|
||||
assert.True(t, user.IsAdmin())
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: mockUserID,
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.False(t, ok)
|
||||
user = NewRegularUser(mockUserID)
|
||||
assert.False(t, user.IsAdmin())
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
@@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
assert.Equal(t, 1, len(users))
|
||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
regularUserID := "regularUser"
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
adminInitiator bool
|
||||
update *User
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Should_Fail_To_Update_Admin_Role",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleUser,
|
||||
Blocked: false,
|
||||
},
|
||||
}, {
|
||||
name: "Should_Fail_When_Admin_Blocks_Themselves",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_To_Update_Non_Existing_User",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
|
||||
expectedErr: true,
|
||||
adminInitiator: false,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Update_User",
|
||||
expectedErr: false,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
// create an account and an admin user
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a regular user
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
err = manager.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
initiatorID := userID
|
||||
if !tc.adminInitiator {
|
||||
initiatorID = regularUserID
|
||||
}
|
||||
|
||||
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
|
||||
if tc.expectedErr {
|
||||
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||
} else {
|
||||
require.NoError(t, err, "expecting SaveUser not to throw an error")
|
||||
assert.NotNil(t, updated)
|
||||
|
||||
assert.Equal(t, string(tc.update.Role), updated.Role)
|
||||
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ download_release_binary() {
|
||||
echo "Installing $1 from $DOWNLOAD_URL"
|
||||
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
||||
|
||||
|
||||
if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then
|
||||
INSTALL_DIR="/Applications/NetBird UI.app"
|
||||
|
||||
@@ -43,8 +44,9 @@ download_release_binary() {
|
||||
unzip -q -o "$BINARY_NAME"
|
||||
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
|
||||
else
|
||||
sudo mkdir -p "$INSTALL_DIR"
|
||||
tar -xzvf "$BINARY_NAME"
|
||||
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR"
|
||||
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -281,4 +283,4 @@ install_netbird() {
|
||||
echo "sudo netbird up"
|
||||
}
|
||||
|
||||
install_netbird
|
||||
install_netbird
|
||||
|
||||
@@ -2,44 +2,44 @@ package sharedsock
|
||||
|
||||
import "golang.org/x/net/bpf"
|
||||
|
||||
// STUNFilter implements BPFFilter by filtering on STUN packets.
|
||||
// IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port.
|
||||
// Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard).
|
||||
type STUNFilter struct {
|
||||
type IncomingSTUNFilter struct {
|
||||
}
|
||||
|
||||
// NewSTUNFilter creates an instance of a STUNFilter
|
||||
func NewSTUNFilter() BPFFilter {
|
||||
return &STUNFilter{}
|
||||
// NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter
|
||||
func NewIncomingSTUNFilter() BPFFilter {
|
||||
return &IncomingSTUNFilter{}
|
||||
}
|
||||
|
||||
// GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets
|
||||
func (sf STUNFilter) GetInstructions(port uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
|
||||
raw4, err = rawInstructions(22, 32, port)
|
||||
func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) {
|
||||
raw4, err = rawInstructions(22, 32, dstPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
raw6, err = rawInstructions(2, 12, port)
|
||||
raw6, err = rawInstructions(2, 12, dstPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return raw4, raw6, nil
|
||||
}
|
||||
|
||||
func rawInstructions(portOff, cookieOff, port uint32) ([]bpf.RawInstruction, error) {
|
||||
func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) {
|
||||
// UDP raw socket for ipv4 receives the rcvdPacket with IP headers
|
||||
// UDP raw socket for ipv6 receives the rcvdPacket with UDP headers
|
||||
instructions := []bpf.Instruction{
|
||||
// Load the source port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
|
||||
bpf.LoadAbsolute{Off: portOff, Size: 2},
|
||||
// Check if the source port is equal to the specified `port`. If not, skip the next 3 instructions.
|
||||
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: port, SkipTrue: 3},
|
||||
// Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6)
|
||||
bpf.LoadAbsolute{Off: dstPortOff, Size: 2},
|
||||
// Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions.
|
||||
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3},
|
||||
// Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6)
|
||||
bpf.LoadAbsolute{Off: cookieOff, Size: 4},
|
||||
// Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction.
|
||||
bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1},
|
||||
// If both the port and the magic cookie match, return a positive value (0xffffffff)
|
||||
// If both the dstPort and the magic cookie match, return a positive value (0xffffffff)
|
||||
bpf.RetConstant{Val: 0xffffffff},
|
||||
// If either the port or the magic cookie doesn't match, return 0
|
||||
// If either the dstPort or the magic cookie doesn't match, return 0
|
||||
bpf.RetConstant{Val: 0},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
package sharedsock
|
||||
|
||||
// NewSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
|
||||
func NewSTUNFilter() BPFFilter {
|
||||
// NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux
|
||||
func NewIncomingSTUNFilter() BPFFilter {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,7 +27,8 @@ import (
|
||||
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")
|
||||
|
||||
// SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered
|
||||
// by BPF instructions (e.g., STUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
|
||||
// by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)).
|
||||
// It is meant to be used when sharing a port with some other process.
|
||||
type SharedSocket struct {
|
||||
ctx context.Context
|
||||
conn4 *socket.Conn
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
|
||||
|
||||
// create raw socket on a port
|
||||
testingPort := 51821
|
||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
||||
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||
err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
require.NoError(t, err, "unable to set deadline, error: %s", err)
|
||||
@@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) {
|
||||
|
||||
func TestShouldNotReadNonSTUNPackets(t *testing.T) {
|
||||
testingPort := 39439
|
||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
||||
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||
defer rawSock.Close()
|
||||
|
||||
@@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) {
|
||||
defer udpListener.Close()
|
||||
|
||||
testingPort := 39440
|
||||
rawSock, err := Listen(testingPort, NewSTUNFilter())
|
||||
rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
|
||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||
defer rawSock.Close()
|
||||
|
||||
@@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSharedSocket_Close(t *testing.T) {
|
||||
rawSock, err := Listen(39440, NewSTUNFilter())
|
||||
rawSock, err := Listen(39440, NewIncomingSTUNFilter())
|
||||
require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
|
||||
|
||||
errGrp := errgroup.Group{}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package sharedsock
|
||||
|
||||
|
||||
@@ -350,7 +350,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
|
||||
decryptedMessage, err := c.decryptMessage(msg)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user