Compare commits

...

10 Commits

Author SHA1 Message Date
Zoltán Papp
8ac5e9d866 Fix log 2024-10-31 19:07:38 +01:00
Zoltán Papp
954e038da0 Add more logs 2024-10-31 18:19:57 +01:00
Zoltán Papp
9ccc6c6547 Add nil value check 2024-10-31 16:48:10 +01:00
Zoltan Papp
2a3262f5a8 Print debug info 2024-10-29 13:54:35 +01:00
pascal-fischer
10480eb52f [management] Setup key improvements (#2775) 2024-10-28 17:52:23 +01:00
pascal-fischer
1e44c5b574 [client] allow relay leader on iOS (#2795) 2024-10-28 16:55:00 +01:00
Viktor Liu
940f8b4547 [client] Remove legacy forwarding rules in userspace mode (#2782) 2024-10-28 12:29:29 +01:00
Viktor Liu
46e37fa04c [client] Ignore route rules with no sources instead of erroring out (#2786) 2024-10-28 12:28:44 +01:00
Stefano
b9f205b2ce [misc] Update Zitadel from v2.54.10 to v2.64.1 2024-10-28 10:08:17 +01:00
Viktor Liu
0fd874fa45 [client] Make native firewall init fail firewall creation (#2784) 2024-10-28 10:02:27 +01:00
33 changed files with 681 additions and 265 deletions

View File

@@ -3,6 +3,7 @@
package firewall package firewall
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
@@ -37,62 +38,55 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, errFw := createNativeFirewall(iface) fm, err := createNativeFirewall(iface, stateManager)
if fm != nil { if !iface.IsUserspaceBind() {
if err := fm.Init(stateManager); err != nil { return fm, err
log.Errorf("failed to init nftables manager: %s", err)
}
} }
if iface.IsUserspaceBind() { if err != nil {
return createUserspaceFirewall(iface, fm, errFw) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
} }
return fm, errFw func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
} }
func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
return createIptablesFirewall(iface) log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
case NFTABLES: case NFTABLES:
return createNftablesFirewall(iface) log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
default: default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, fmt.Errorf("no firewall manager found") return nil, errors.New("no firewall manager found")
} }
} }
func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
log.Info("creating an iptables firewall manager")
fm, err := nbiptables.Create(iface)
if err != nil {
log.Errorf("failed to create iptables manager: %s", err)
}
return fm, err
}
func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
log.Info("creating an nftables firewall manager")
fm, err := nbnftables.Create(iface)
if err != nil {
log.Errorf("failed to create nftables manager: %s", err)
}
return fm, err
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) {
var errUsp error var errUsp error
if errFw == nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else { } else {
fm, errUsp = uspfilter.Create(iface) fm, errUsp = uspfilter.Create(iface)
} }
if errUsp != nil { if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp) return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
return nil, errUsp
} }
if err := fm.AllowNetbird(); err != nil { if err := fm.AllowNetbird(); err != nil {

View File

@@ -296,6 +296,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
} }
} }

View File

@@ -230,23 +230,7 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management // SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
oldLegacy := m.router.legacyManagement return firewall.SetLegacyManagement(m.router, isLegacy)
if oldLegacy != isLegacy {
m.router.legacyManagement = isLegacy
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
if !isLegacy && oldLegacy {
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state

View File

@@ -551,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
} }
if err := r.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
} }
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -237,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
} }
// SetLegacyManagement doesn't need to be implemented for this manager // SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(_ bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return nil if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager

View File

@@ -3,6 +3,7 @@ package acl
import ( import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -10,14 +11,18 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap)
@@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
var newRouteRules = make(map[id.RuleID]struct{}) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule) id, err := d.applyRouteACL(rule)
if err != nil { if err != nil {
return fmt.Errorf("apply route ACL: %w", err) if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
continue
} }
newRouteRules[id] = struct{}{} newRouteRules[id] = struct{}{}
} }
// Clean up old firewall rules
for id := range d.routeRules { for id := range d.routeRules {
if _, ok := newRouteRules[id]; !ok { if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil { if err := d.firewall.DeleteRouteRule(id); err != nil {
log.Errorf("failed to delete route firewall rule: %v", err) merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
continue
} }
delete(d.routeRules, id) // implicitly deleted from the map
} }
} }
d.routeRules = newRouteRules d.routeRules = newRouteRules
return nil return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", fmt.Errorf("source ranges is empty") return "", ErrSourceRangesEmpty
} }
var sources []netip.Prefix var sources []netip.Prefix

View File

@@ -333,9 +333,12 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
ep = wgProxy.EndpointAddr() ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy conn.wgProxyICE = wgProxy
} else { } else {
conn.log.Infof("direct iceConnInfo: %v", iceConnInfo.RemoteConn)
agentCheck(conn.log, iceConnInfo.Agent)
nilCheck(conn.log, iceConnInfo.RemoteConn)
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil { if err != nil {
log.Errorf("failed to resolveUDPaddr") conn.log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil) conn.handleConfigurationFailure(err, nil)
return return
} }

View File

@@ -1,13 +1,14 @@
package ice package ice
import ( import (
"github.com/netbirdio/netbird/client/internal/stdnet" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
"github.com/pion/randutil" "github.com/pion/randutil"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"runtime"
"time" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
const ( const (
@@ -77,10 +78,7 @@ func CandidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() { if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay} return []ice.CandidateType{ice.CandidateTypeRelay}
} }
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
} }

View File

@@ -0,0 +1,54 @@
package peer
import (
"net"
"reflect"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
)
func nilCheck(log *log.Entry, conn net.Conn) {
if conn == nil {
log.Infof("conn is nil")
return
}
if conn.RemoteAddr() == nil {
log.Infof("conn.RemoteAddr() is nil")
return
}
if reflect.ValueOf(conn.RemoteAddr()).IsNil() {
log.Infof("value of conn.RemoteAddr() is nil")
return
}
}
func agentCheck(log *log.Entry, agent *ice.Agent) {
if agent == nil {
log.Errorf("agent is nil")
return
}
pair, err := agent.GetSelectedCandidatePair()
if err != nil {
log.Errorf("error getting selected candidate pair: %v", err)
return
}
if pair == nil {
log.Errorf("pair is nil")
return
}
if pair.Remote == nil {
log.Errorf("pair.Remote is nil")
return
}
if pair.Remote.Address() == "" {
log.Errorf("address is empty")
return
}
}

View File

@@ -29,6 +29,7 @@ type ICEConnInfo struct {
LocalIceCandidateEndpoint string LocalIceCandidateEndpoint string
Relayed bool Relayed bool
RelayedOnLocal bool RelayedOnLocal bool
Agent *ice.Agent
} }
type WorkerICECallbacks struct { type WorkerICECallbacks struct {
@@ -126,6 +127,9 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
return return
} }
w.log.Infof("check remoteConn: %v", remoteConn)
w.log.Infof("check remoteConn.RemoteAddr: %v", remoteConn.RemoteAddr())
nilCheck(w.log, remoteConn)
w.log.Debugf("agent dial succeeded") w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair() pair, err := w.agent.GetSelectedCandidatePair()
@@ -154,6 +158,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair), Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
Agent: agent,
} }
w.log.Debugf("on ICE conn read to use ready") w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci) go w.conn.OnConnReady(w.selectedPriority, ci)
@@ -322,8 +327,10 @@ func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key isControlling := w.config.LocalKey > w.config.Key
if isControlling { if isControlling {
w.log.Infof("dialing remote peer %s as controlling", w.config.Key)
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
w.log.Infof("dialing remote peer %s as controlled", w.config.Key)
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }
} }

View File

@@ -873,7 +873,7 @@ services:
zitadel: zitadel:
restart: 'always' restart: 'always'
networks: [netbird] networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.54.10' image: 'ghcr.io/zitadel/zitadel:v2.64.1'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file: env_file:
- ./zitadel.env - ./zitadel.env

View File

@@ -153,6 +153,7 @@ type AccountManager interface {
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {

View File

@@ -1010,7 +1010,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
@@ -1035,10 +1034,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String())
} }
if peer.SetupKey != expectedSetupKey {
t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey)
}
if account.Network.CurrentSerial() != 1 { if account.Network.CurrentSerial() != 1 {
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
} }
@@ -2367,7 +2362,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
LoginExpired: false, LoginExpired: false,
}, },
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
SetupKey: "key",
}, },
"peer-2": { "peer-2": {
Status: &nbpeer.PeerStatus{ Status: &nbpeer.PeerStatus{
@@ -2375,7 +2369,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
LoginExpired: false, LoginExpired: false,
}, },
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
SetupKey: "key",
}, },
}, },
expiration: time.Second, expiration: time.Second,
@@ -2529,7 +2522,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
LoginExpired: false, LoginExpired: false,
}, },
InactivityExpirationEnabled: true, InactivityExpirationEnabled: true,
SetupKey: "key",
}, },
"peer-2": { "peer-2": {
Status: &nbpeer.PeerStatus{ Status: &nbpeer.PeerStatus{
@@ -2537,7 +2529,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
LoginExpired: false, LoginExpired: false,
}, },
InactivityExpirationEnabled: true, InactivityExpirationEnabled: true,
SetupKey: "key",
}, },
}, },
expiration: time.Second, expiration: time.Second,

View File

@@ -146,6 +146,8 @@ const (
AccountPeerInactivityExpirationEnabled Activity = 65 AccountPeerInactivityExpirationEnabled Activity = 65
AccountPeerInactivityExpirationDisabled Activity = 66 AccountPeerInactivityExpirationDisabled Activity = 66
AccountPeerInactivityExpirationDurationUpdated Activity = 67 AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@@ -219,6 +221,7 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -530,10 +530,9 @@ components:
type: string type: string
example: reusable example: reusable
expires_in: expires_in:
description: Expiration time in seconds description: Expiration time in seconds, 0 will mean the key never expires
type: integer type: integer
minimum: 86400 minimum: 0
maximum: 31536000
example: 86400 example: 86400
revoked: revoked:
description: Setup key revocation status description: Setup key revocation status
@@ -2018,6 +2017,32 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Setup Key
description: Delete a Setup Key
tags: [ Setup Keys ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: keyId
required: true
schema:
type: string
description: The unique identifier of a setup key
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/groups: /api/groups:
get: get:
summary: List all Groups summary: List all Groups

View File

@@ -1101,7 +1101,7 @@ type SetupKeyRequest struct {
// Ephemeral Indicate that the peer will be ephemeral or not // Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"` Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds // ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
// Name Setup Key name // Name Setup Key name

View File

@@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
} }
func (apiHandler *apiHandler) addPoliciesEndpoint() { func (apiHandler *apiHandler) addPoliciesEndpoint() {

View File

@@ -13,12 +13,13 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) {
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
ID: testPeerID, ID: testPeerID,
Key: "key", Key: "key",
SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"), IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true}, Status: &nbpeer.PeerStatus{Connected: true},
Name: "PeerName", Name: "PeerName",

View File

@@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
expiresIn := time.Duration(req.ExpiresIn) * time.Second expiresIn := time.Duration(req.ExpiresIn) * time.Second
day := time.Hour * 24 if expiresIn < 0 {
year := day * 365 util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w)
if expiresIn < day || expiresIn > year {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
return return
} }
@@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil { if req.Ephemeral != nil {
ephemeral = *req.Ephemeral ephemeral = *req.Ephemeral
} }
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral) req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil { if err != nil {
@@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
return return
} }
writeSuccess(r.Context(), w, setupKey) apiSetupKeys := toResponseBody(setupKey)
// for the creation we need to send the plain key
apiSetupKeys.Key = setupKey.Key
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
// GetSetupKey is a GET request to get a SetupKey by ID // GetSetupKey is a GET request to get a SetupKey by ID
@@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
return return
} }
@@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
return return
} }
@@ -181,6 +184,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
return
}
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200) w.WriteHeader(200)
@@ -206,7 +233,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
return &api.SetupKey{ return &api.SetupKey{
Id: key.Id, Id: key.Id,
Key: key.Key, Key: key.KeySecret,
Name: key.Name, Name: key.Name,
Expires: key.ExpiresAt, Expires: key.ExpiresAt,
Type: string(key.Type), Type: string(key.Type),

View File

@@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil return []*server.SetupKey{defaultKey}, nil
}, },
DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error {
if keyID == defaultKey.Id {
return nil
}
return status.Errorf(status.NotFound, "key %s not found", keyID)
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
} }
func TestSetupKeysHandlers(t *testing.T) { func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey, _ := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true) server.SetupKeyUnlimitedUsage, true)
newSetupKey.Key = plainKey
updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true updatedDefaultSetupKey.Revoked = true
expectedNewKey := toResponseBody(newSetupKey)
expectedNewKey.Key = plainKey
tt := []struct { tt := []struct {
name string name string
requestType string requestType string
@@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) {
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))), []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedSetupKey: toResponseBody(newSetupKey), expectedSetupKey: expectedNewKey,
}, },
{ {
name: "Update Setup Key", name: "Update Setup Key",
@@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) {
expectedBody: true, expectedBody: true,
expectedSetupKey: toResponseBody(updatedDefaultSetupKey), expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
}, },
{
name: "Delete Setup Key",
requestType: http.MethodDelete,
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
requestBody: bytes.NewBuffer([]byte("")),
expectedStatus: http.StatusOK,
expectedBody: false,
},
} }
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
@@ -164,6 +182,7 @@ func TestSetupKeysHandlers(t *testing.T) {
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
peersSSHEnabled++ peersSSHEnabled++
} }
if peer.SetupKey == "" { if peer.UserID != "" {
userPeers++ userPeers++
} }

View File

@@ -2,13 +2,16 @@ package migration
import ( import (
"context" "context"
"crypto/sha256"
"database/sql" "database/sql"
b64 "encoding/base64"
"encoding/gob" "encoding/gob"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"unicode/utf8"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi
return nil return nil
} }
func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error {
oldColumnName := "key"
newColumnName := "key_secret"
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if err := db.Transaction(func(tx *gorm.DB) error {
if !tx.Migrator().HasColumn(&model, newColumnName) {
log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName)
if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
}
var rows []map[string]any
if err := tx.Table(tableName).
Select("id", oldColumnName, newColumnName).
Where(newColumnName + " IS NULL OR " + newColumnName + " = ''").
Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'").
Find(&rows).Error; err != nil {
return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err)
}
if len(rows) == 0 {
log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName)
return nil
}
for _, row := range rows {
var plainKey string
if columnValue := row[oldColumnName]; columnValue != nil {
value, ok := columnValue.(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
plainKey = value
}
secretKey := hiddenKey(plainKey, 4)
hashedKey := sha256.Sum256([]byte(plainKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil {
return fmt.Errorf("update row with secret key: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil {
return fmt.Errorf("update row with hashed key: %w", err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
}
return nil
}); err != nil {
return err
}
log.Printf("Migration of plain setup key to hashed setup key completed")
return nil
}
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func hiddenKey(key string, length int) string {
prefix := key[0:5]
if length > utf8.RuneCountInString(key) {
length = utf8.RuneCountInString(key) - len(prefix)
}
return prefix + strings.Repeat("*", length)
}

View File

@@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr) db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged") assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
} }
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
}
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
}
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
}

View File

@@ -109,6 +109,14 @@ type MockAccountManager struct {
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
}
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
if am.DeleteSetupKeyFunc != nil {
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
}
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {

View File

@@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
b64 "encoding/base64"
"fmt" "fmt"
"net" "net"
"slices" "slices"
@@ -396,6 +398,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
upperKey := strings.ToUpper(setupKey) upperKey := strings.ToUpper(setupKey)
hashedKey := sha256.Sum256([]byte(upperKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
var accountID string var accountID string
var err error var err error
addedByUser := false addedByUser := false
@@ -403,7 +407,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser = true addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID) accountID, err = am.Store.GetAccountIDByUserID(userID)
} else { } else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
} }
if err != nil { if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
@@ -448,7 +452,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
opEvent.Activity = activity.PeerAddedByUser opEvent.Activity = activity.PeerAddedByUser
} else { } else {
// Validate the setup key // Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to get setup key: %w", err) return fmt.Errorf("failed to get setup key: %w", err)
} }
@@ -489,7 +493,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID, AccountID: accountID,
Key: peer.Key, Key: peer.Key,
SetupKey: upperKey,
IP: freeIP, IP: freeIP,
Meta: peer.Meta, Meta: peer.Meta,
Name: peer.Meta.Hostname, Name: peer.Meta.Hostname,

View File

@@ -16,8 +16,6 @@ type Peer struct {
AccountID string `json:"-" gorm:"index"` AccountID string `json:"-" gorm:"index"`
// WireGuard public key // WireGuard public key
Key string `gorm:"index"` Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string `diff:"-"`
// IP address of the Peer // IP address of the Peer
IP net.IP `gorm:"serializer:json"` IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data // Meta is a Peer system meta data
@@ -175,7 +173,6 @@ func (p *Peer) Copy() *Peer {
ID: p.ID, ID: p.ID,
AccountID: p.AccountID, AccountID: p.AccountID,
Key: p.Key, Key: p.Key,
SetupKey: p.SetupKey,
IP: p.IP, IP: p.IP,
Meta: p.Meta, Meta: p.Meta,
Name: p.Name, Name: p.Name,

View File

@@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
b64 "encoding/base64"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -1090,7 +1092,6 @@ func Test_RegisterPeerByUser(t *testing.T) {
ID: xid.New().String(), ID: xid.New().String(),
AccountID: existingAccountID, AccountID: existingAccountID,
Key: "newPeerKey", Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123}, IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer", Hostname: "newPeer",
@@ -1155,7 +1156,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ID: xid.New().String(), ID: xid.New().String(),
AccountID: existingAccountID, AccountID: existingAccountID,
Key: "newPeerKey", Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "", UserID: "",
IP: net.IP{123, 123, 123, 123}, IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
@@ -1175,7 +1175,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID) account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
@@ -1187,8 +1186,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) hashedKey := sha256.Sum256([]byte(existingSetupKeyID))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed)
assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes)
} }
@@ -1221,7 +1223,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
ID: xid.New().String(), ID: xid.New().String(),
AccountID: existingAccountID, AccountID: existingAccountID,
Key: "newPeerKey", Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "", UserID: "",
IP: net.IP{123, 123, 123, 123}, IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
@@ -1250,8 +1251,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) hashedKey := sha256.Sum256([]byte(faultyKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC())
assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes)
} }
func TestPeerAccountPeersUpdate(t *testing.T) { func TestPeerAccountPeersUpdate(t *testing.T) {

View File

@@ -2,6 +2,9 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/fnv" "hash/fnv"
"strconv" "strconv"
"strings" "strings"
@@ -73,6 +76,7 @@ type SetupKey struct {
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"` AccountID string `json:"-" gorm:"index"`
Key string Key string
KeySecret string
Name string Name string
Type SetupKeyType Type SetupKeyType
CreatedAt time.Time CreatedAt time.Time
@@ -104,6 +108,7 @@ func (key *SetupKey) Copy() *SetupKey {
Id: key.Id, Id: key.Id,
AccountID: key.AccountID, AccountID: key.AccountID,
Key: key.Key, Key: key.Key,
KeySecret: key.KeySecret,
Name: key.Name, Name: key.Name,
Type: key.Type, Type: key.Type,
CreatedAt: key.CreatedAt, CreatedAt: key.CreatedAt,
@@ -120,19 +125,17 @@ func (key *SetupKey) Copy() *SetupKey {
// EventMeta returns activity event meta related to the setup key // EventMeta returns activity event meta related to the setup key
func (key *SetupKey) EventMeta() map[string]any { func (key *SetupKey) EventMeta() map[string]any {
return map[string]any{"name": key.Name, "type": key.Type, "key": key.HiddenCopy(1).Key} return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
} }
// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix. // hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************" // E.g., "831F6*******************************"
func (key *SetupKey) HiddenCopy(length int) *SetupKey { func hiddenKey(key string, length int) string {
k := key.Copy() prefix := key[0:5]
prefix := k.Key[0:5] if length > utf8.RuneCountInString(key) {
if length > utf8.RuneCountInString(key.Key) { length = utf8.RuneCountInString(key) - len(prefix)
length = utf8.RuneCountInString(key.Key) - len(prefix)
} }
k.Key = prefix + strings.Repeat("*", length) return prefix + strings.Repeat("*", length)
return k
} }
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
@@ -155,6 +158,9 @@ func (key *SetupKey) IsRevoked() bool {
// IsExpired if key was expired // IsExpired if key was expired
func (key *SetupKey) IsExpired() bool { func (key *SetupKey) IsExpired() bool {
if key.ExpiresAt.IsZero() {
return false
}
return time.Now().After(key.ExpiresAt) return time.Now().After(key.ExpiresAt)
} }
@@ -169,30 +175,40 @@ func (key *SetupKey) IsOverUsed() bool {
// GenerateSetupKey generates a new setup key // GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
usageLimit int, ephemeral bool) *SetupKey { usageLimit int, ephemeral bool) (*SetupKey, string) {
key := strings.ToUpper(uuid.New().String()) key := strings.ToUpper(uuid.New().String())
limit := usageLimit limit := usageLimit
if t == SetupKeyOneOff { if t == SetupKeyOneOff {
limit = 1 limit = 1
} }
expiresAt := time.Time{}
if validFor != 0 {
expiresAt = time.Now().UTC().Add(validFor)
}
hashedKey := sha256.Sum256([]byte(key))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{ return &SetupKey{
Id: strconv.Itoa(int(Hash(key))), Id: strconv.Itoa(int(Hash(key))),
Key: key, Key: encodedHashedKey,
KeySecret: hiddenKey(key, 4),
Name: name, Name: name,
Type: t, Type: t,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(validFor), ExpiresAt: expiresAt,
UpdatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
Revoked: false, Revoked: false,
UsedTimes: 0, UsedTimes: 0,
AutoGroups: autoGroups, AutoGroups: autoGroups,
UsageLimit: limit, UsageLimit: limit,
Ephemeral: ephemeral, Ephemeral: ephemeral,
} }, key
} }
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration // GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
func GenerateDefaultSetupKey() *SetupKey { func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false) SetupKeyUnlimitedUsage, false)
} }
@@ -213,11 +229,6 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -227,7 +238,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, err return nil, err
} }
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
@@ -246,6 +257,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
} }
} }
// for the creation return the plain key to the caller
setupKey.Key = plainKey
return setupKey, nil return setupKey, nil
} }
@@ -334,7 +348,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") return nil, status.NewUnauthorizedToViewSetupKeysError()
} }
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
@@ -342,18 +356,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err return nil, err
} }
keys := make([]*SetupKey, 0, len(setupKeys)) return setupKeys, nil
for _, key := range setupKeys {
var k *SetupKey
if !user.IsAdminOrServiceUser() {
k = key.HiddenCopy(999)
} else {
k = key.Copy()
}
keys = append(keys, k)
}
return keys, nil
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -364,7 +367,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") return nil, status.NewUnauthorizedToViewSetupKeysError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
@@ -377,11 +380,33 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
setupKey.UpdatedAt = setupKey.CreatedAt setupKey.UpdatedAt = setupKey.CreatedAt
} }
if !user.IsAdminOrServiceUser() { return setupKey, nil
setupKey = setupKey.HiddenCopy(999)
} }
return setupKey, nil // DeleteSetupKey removes the setup key from the account
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError()
}
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err)
}
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
return nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {

View File

@@ -2,8 +2,11 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -66,7 +69,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now().UTC(), autoGroups) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
@@ -183,7 +186,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
tCase.expectedUpdatedAt, tCase.expectedGroups) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated)
@@ -239,10 +242,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string var expectedAutoGroups []string
key := GenerateDefaultSetupKey() key, plainKey := GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
} }
@@ -256,41 +259,41 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUpdatedAt := time.Now().UTC() expectedUpdatedAt := time.Now().UTC()
var expectedAutoGroups []string var expectedAutoGroups []string
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
} }
func TestSetupKey_IsValid(t *testing.T) { func TestSetupKey_IsValid(t *testing.T) {
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
if !validKey.IsValid() { if !validKey.IsValid() {
t.Errorf("expected key to be valid, got invalid %v", validKey) t.Errorf("expected key to be valid, got invalid %v", validKey)
} }
// expired // expired
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
if expiredKey.IsValid() { if expiredKey.IsValid() {
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
} }
// revoked // revoked
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
revokedKey.Revoked = true revokedKey.Revoked = true
if revokedKey.IsValid() { if revokedKey.IsValid() {
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
} }
// overused // overused
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
overUsedKey.UsedTimes = 1 overUsedKey.UsedTimes = 1
if overUsedKey.IsValid() { if overUsedKey.IsValid() {
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
} }
// overused // overused
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
reusableKey.UsedTimes = 99 reusableKey.UsedTimes = 99
if !reusableKey.IsValid() { if !reusableKey.IsValid() {
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
@@ -299,7 +302,7 @@ func TestSetupKey_IsValid(t *testing.T) {
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
expectedUpdatedAt time.Time, expectedAutoGroups []string) { expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) {
t.Helper() t.Helper()
if key.Name != expectedName { if key.Name != expectedName {
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name) t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
@@ -329,13 +332,23 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt) t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
} }
if expectHashedKey {
if !isValidBase64SHA256(key.Key) {
t.Errorf("expected key to be hashed, got %v", key.Key)
}
} else {
_, err := uuid.Parse(key.Key) _, err := uuid.Parse(key.Key)
if err != nil { if err != nil {
t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err)
} }
}
if key.Id != strconv.Itoa(int(Hash(key.Key))) { if !strings.HasSuffix(key.KeySecret, "****") {
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id) t.Errorf("expected key secret to be secure, got %v", key.Key)
}
if key.Id != expectedID {
t.Errorf("expected key Id %v, got %v", expectedID, key.Id)
} }
if len(key.AutoGroups) != len(expectedAutoGroups) { if len(key.AutoGroups) != len(expectedAutoGroups) {
@@ -344,13 +357,26 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal") assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
} }
func isValidBase64SHA256(encodedKey string) bool {
decoded, err := base64.StdEncoding.DecodeString(encodedKey)
if err != nil {
return false
}
if len(decoded) != sha256.Size {
return false
}
return true
}
func TestSetupKey_Copy(t *testing.T) { func TestSetupKey_Copy(t *testing.T) {
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
keyCopy := key.Copy() keyCopy := key.Copy()
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
key.UpdatedAt, key.AutoGroups) key.UpdatedAt, key.AutoGroups, true)
} }

View File

@@ -469,7 +469,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -741,7 +741,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var accountID string var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -973,7 +973,7 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key)) First(&setupKey, keyQueryCondition, key)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
@@ -1232,6 +1232,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
} }
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID)
}
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T var record []T
@@ -1264,3 +1268,21 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
} }
return &record, nil return &record, nil
} }
// deleteRecordByID deletes a record by its ID and account ID from the database.
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "record not found")
}
return nil
}

View File

@@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
b64 "encoding/base64"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@@ -71,7 +73,7 @@ func runLargeTest(t *testing.T, store Store) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 6000 const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ { for n := 0; n < numPerAccount; n++ {
@@ -81,7 +83,6 @@ func runLargeTest(t *testing.T, store Store) {
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
ID: peerID, ID: peerID,
Key: peerID, Key: peerID,
SetupKey: "",
IP: netIP, IP: netIP,
Name: peerID, Name: peerID,
DNSLabel: peerID, DNSLabel: peerID,
@@ -133,7 +134,7 @@ func runLargeTest(t *testing.T, store Store) {
} }
account.NameServerGroups[nameserver.ID] = nameserver account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
} }
@@ -215,11 +216,10 @@ func TestSqlite_SaveAccount(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -230,11 +230,10 @@ func TestSqlite_SaveAccount(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{ account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2", Key: "peerkey2",
SetupKey: "peerkeysetupkey2",
IP: net.IP{127, 0, 0, 2}, IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2", Name: "peer name 2",
@@ -297,11 +296,10 @@ func TestSqlite_DeleteAccount(t *testing.T) {
}} }}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") account := newAccountWithId(context.Background(), "account_id", testUserID, "")
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -396,7 +394,6 @@ func TestSqlite_SavePeer(t *testing.T) {
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
ID: "testpeer", ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
Name: "peer name", Name: "peer name",
@@ -455,7 +452,6 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
ID: "testpeer", ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -720,11 +716,10 @@ func newSqliteStore(t *testing.T) *SqlStore {
func newAccount(store Store, id int) error { func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id) str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["p"+str] = &nbpeer.Peer{ account.Peers["p"+str] = &nbpeer.Peer{
Key: "peerkey" + str, Key: "peerkey" + str,
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -760,11 +755,10 @@ func TestPostgresql_SaveAccount(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -775,11 +769,10 @@ func TestPostgresql_SaveAccount(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{ account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2", Key: "peerkey2",
SetupKey: "peerkeysetupkey2",
IP: net.IP{127, 0, 0, 2}, IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2", Name: "peer name 2",
@@ -842,11 +835,10 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
}} }}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") account := newAccountWithId(context.Background(), "account_id", testUserID, "")
setupKey := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -923,7 +915,6 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
ID: "testpeer", ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
@@ -1118,12 +1109,17 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
hashedKey := sha256.Sum256([]byte(plainKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
_, err = store.GetAccount(context.Background(), existingAccountID) _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) assert.Equal(t, encodedHashedKey, setupKey.Key)
assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name) assert.Equal(t, "Default key", setupKey.Name)
} }
@@ -1138,24 +1134,28 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
hashedKey := sha256.Sum256([]byte(plainKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
_, err = store.GetAccount(context.Background(), existingAccountID) _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes) assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err) require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes) assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err) require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes) assert.Equal(t, 2, setupKey.UsedTimes)
} }
@@ -1264,3 +1264,32 @@ func TestSqlite_GetGroupByName(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "All", group.Name) require.Equal(t, "All", group.Name)
} }
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
require.Error(t, err)
}
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
require.Error(t, err)
}

View File

@@ -114,3 +114,13 @@ func NewGetAccountFromStoreError(err error) error {
func NewGetUserFromStoreError() error { func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store") return Errorf(Internal, "issue getting user from store")
} }
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")
}
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
func NewUnauthorizedToViewSetupKeysError() error {
return Errorf(Unauthorized, "only users with admin power can view setup keys")
}

View File

@@ -124,6 +124,7 @@ type Store interface {
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
} }
type StoreEngine string type StoreEngine string
@@ -241,6 +242,9 @@ func getMigrations(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error { func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
}, },
func(db *gorm.DB) error {
return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db)
},
} }
} }