Release 0.28.0 (#2092)

* compile client under freebsd (#1620)

Compile netbird client under freebsd and now support netstack and userspace modes.
Refactoring linux specific code to share same code with FreeBSD, move to *_unix.go files.

Not implemented yet:

Kernel mode not supported
DNS probably does not work yet
Routing also probably does not work yet
SSH support did not tested yet
Lack of test environment for freebsd (dedicated VM for github runners under FreeBSD required)
Lack of tests for freebsd specific code
info reporting need to review and also implement, for example OS reported as GENERIC instead of FreeBSD (lack of FreeBSD icon in management interface)
Lack of proper client setup under FreeBSD
Lack of FreeBSD port/package

* Add DNS routes (#1943)

Given domains are resolved periodically and resolved IPs are replaced with the new ones. Unless the flag keep_route is set to true, then only new ones are added.
This option is helpful if there are long-running connections that might still point to old IP addresses from changed DNS records.

* Add process posture check (#1693)

Introduces a process posture check to validate the existence and active status of specific binaries on peer systems. The check ensures that files are present at specified paths, and that corresponding processes are running. This check supports Linux, Windows, and macOS systems.


Co-authored-by: Evgenii <mail@skillcoder.com>
Co-authored-by: Pascal Fischer <pascal@netbird.io>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Maycon Santos
2024-06-13 13:24:24 +02:00
committed by GitHub
parent 95299be52d
commit 4fec709bb1
149 changed files with 6509 additions and 2710 deletions

View File

@@ -12,12 +12,13 @@ import (
type Client interface {
io.Closer
Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap() (*proto.NetworkMap, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
}

View File

@@ -257,7 +257,7 @@ func TestClient_Sync(t *testing.T) {
ch := make(chan *mgmtProto.SyncResponse, 1)
go func() {
err = client.Sync(context.Background(), func(msg *mgmtProto.SyncResponse) error {
err = client.Sync(context.Background(), info, func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})

View File

@@ -29,6 +29,11 @@ import (
const ConnectTimeout = 10 * time.Second
const (
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
errMsgNoMgmtConnection = "no connection to management"
)
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected(error)
@@ -113,13 +118,11 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(ctx)
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
@@ -129,55 +132,60 @@ func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncRe
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
return err
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
backOff.Reset() // reset backoff counter after successful connection
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
return nil
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
}
err := backoff.Retry(operation, backOff)
err := backoff.Retry(operation, defaultBackoff(ctx))
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
}
return err
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
return err
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
if err != nil {
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
return nil
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
@@ -186,7 +194,7 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
@@ -219,8 +227,8 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
@@ -269,7 +277,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
@@ -316,7 +324,7 @@ func (c *GrpcClient) IsHealthy() bool {
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
@@ -431,6 +439,35 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
return flowInfoResp, nil
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return fmt.Errorf(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: syncMetaReq,
})
return err
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
@@ -464,6 +501,15 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
})
}
files := make([]*proto.File, 0, len(info.Files))
for _, file := range info.Files {
files = append(files, &proto.File{
Path: file.Path,
Exist: file.Exist,
ProcessIsRunning: file.ProcessIsRunning,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
@@ -483,5 +529,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
Files: files,
}
}

View File

@@ -11,12 +11,13 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
}
func (m *MockClient) IsHealthy() bool {
@@ -30,11 +31,11 @@ func (m *MockClient) Close() error {
return m.CloseFunc()
}
func (m *MockClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(ctx, msgHandler)
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
@@ -73,6 +74,13 @@ func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
}

View File

@@ -0,0 +1,34 @@
package domain
import (
"golang.org/x/net/idna"
)
type Domain string
// String converts the Domain to a non-punycode string.
func (d Domain) String() (string, error) {
unicode, err := idna.ToUnicode(string(d))
if err != nil {
return "", err
}
return unicode, nil
}
// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails.
func (d Domain) SafeString() string {
str, err := d.String()
if err != nil {
str = string(d)
}
return str
}
// FromString creates a Domain from a string, converting it to punycode.
func FromString(s string) (Domain, error) {
ascii, err := idna.ToASCII(s)
if err != nil {
return "", err
}
return Domain(ascii), nil
}

83
management/domain/list.go Normal file
View File

@@ -0,0 +1,83 @@
package domain
import "strings"
type List []Domain
// ToStringList converts a List to a slice of string.
func (d List) ToStringList() ([]string, error) {
var list []string
for _, domain := range d {
s, err := domain.String()
if err != nil {
return nil, err
}
list = append(list, s)
}
return list, nil
}
// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings.
func (d List) ToPunycodeList() []string {
var list []string
for _, domain := range d {
list = append(list, string(domain))
}
return list
}
// ToSafeStringList converts the List to a slice of non-punycode strings.
// If a domain cannot be converted, the original string is used.
func (d List) ToSafeStringList() []string {
var list []string
for _, domain := range d {
list = append(list, domain.SafeString())
}
return list
}
// String converts List to a comma-separated string.
func (d List) String() (string, error) {
list, err := d.ToStringList()
if err != nil {
return "", err
}
return strings.Join(list, ", "), nil
}
// SafeString converts List to a comma-separated non-punycode string.
// If a domain cannot be converted, the original string is used.
func (d List) SafeString() string {
str, err := d.String()
if err != nil {
return strings.Join(d.ToPunycodeList(), ", ")
}
return str
}
// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains.
func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ")
}
// FromStringList creates a DomainList from a slice of string.
func FromStringList(s []string) (List, error) {
var dl List
for _, domain := range s {
d, err := FromString(domain)
if err != nil {
return nil, err
}
dl = append(dl, d)
}
return dl, nil
}
// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings.
func FromPunycodeList(s []string) List {
var dl List
for _, domain := range s {
dl = append(dl, Domain(domain))
}
return dl
}

File diff suppressed because it is too large Load Diff

View File

@@ -38,6 +38,12 @@ service ManagementService {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {
@@ -50,7 +56,10 @@ message EncryptedMessage {
int32 version = 3;
}
message SyncRequest {}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs)
message SyncResponse {
@@ -69,6 +78,14 @@ message SyncResponse {
bool remotePeersIsEmpty = 4;
NetworkMap NetworkMap = 5;
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
}
message SyncMetaRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
message LoginRequest {
@@ -82,6 +99,7 @@ message LoginRequest {
PeerKeys peerKeys = 4;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
@@ -100,6 +118,16 @@ message Environment {
string platform = 2;
}
// File represents a file on the system.
message File {
// path is the path to the file.
string path = 1;
// exist indicate whether the file exists.
bool exist = 2;
// processIsRunning indicates whether the file is a running process or not.
bool processIsRunning = 3;
}
// PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta {
string hostname = 1;
@@ -117,6 +145,7 @@ message PeerSystemMeta {
string sysProductName = 13;
string sysManufacturer = 14;
Environment environment = 15;
repeated File files = 16;
}
message LoginResponse {
@@ -124,6 +153,8 @@ message LoginResponse {
WiretrusteeConfig wiretrusteeConfig = 1;
// Peer local config
PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client
repeated Checks Checks = 3;
}
message ServerKeyResponse {
@@ -303,6 +334,8 @@ message Route {
int64 Metric = 5;
bool Masquerade = 6;
string NetID = 7;
repeated string Domains = 8;
bool keepRoute = 9;
}
// DNSConfig represents a dns.Update
@@ -371,3 +404,7 @@ message NetworkAddress {
string netIP = 1;
string mac = 2;
}
message Checks {
repeated string Files= 1;
}

View File

@@ -43,6 +43,11 @@ type ManagementServiceClient interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
@@ -130,6 +135,15 @@ func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context,
return out, nil
}
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
@@ -159,6 +173,11 @@ type ManagementServiceServer interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
@@ -184,6 +203,9 @@ func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.C
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -308,6 +330,24 @@ func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx co
return interceptor(ctx, in, info, handler)
}
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/SyncMeta",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -335,6 +375,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetPKCEAuthorizationFlow",
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
},
{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -20,6 +20,7 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -102,7 +103,7 @@ type AccountManager interface {
DeletePolicy(accountID, policyID, userID string) error
ListPolicies(accountID, userID string) ([]*Policy, error)
GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error
DeleteRoute(accountID string, routeID route.ID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
@@ -117,6 +118,7 @@ type AccountManager interface {
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
@@ -131,8 +133,9 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}
@@ -275,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
groupListMap := a.getPeerGroups(peerID)
@@ -293,7 +296,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(route.GetHAUniqueID(r))]
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
@@ -376,11 +379,13 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
if r.IsDynamic() && r.Domains.PunycodeString() == domains.PunycodeString() {
routes = append(routes, r)
} else if r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
@@ -1850,7 +1855,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
@@ -1867,7 +1872,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
return nil, nil, err
}
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil {
return nil, nil, err
}
@@ -1906,6 +1911,27 @@ func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
}
func (am *DefaultAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountReadLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
_, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
if err != nil {
return mapError(err)
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil

View File

@@ -1435,7 +1435,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
},
}
routes := account.GetRoutesByPrefix(prefix)
routes := account.GetRoutesByPrefixOrDomains(prefix, nil)
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)

View File

@@ -11,6 +11,7 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/management/server/posture"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
@@ -134,7 +135,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
if err != nil {
return mapError(err)
}
@@ -157,12 +162,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
// keep a connection to the peer and send updates when available
return s.handleUpdates(peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
}
@@ -174,21 +182,10 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
}
log.Debugf("received an update for peer %s", peerKey.String())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed processing update message")
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
return err
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
@@ -199,6 +196,26 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
}
}
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
@@ -250,14 +267,18 @@ func mapError(err error) error {
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
osVersion := loginReq.GetMeta().GetOSVersion()
if osVersion == "" {
osVersion = loginReq.GetMeta().GetCore()
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses()))
for _, addr := range loginReq.GetMeta().GetNetworkAddresses() {
osVersion := meta.GetOSVersion()
if osVersion == "" {
osVersion = meta.GetCore()
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses()))
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
@@ -269,24 +290,34 @@ func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
})
}
files := make([]nbpeer.File, 0, len(meta.GetFiles()))
for _, file := range meta.GetFiles() {
files = append(files, nbpeer.File{
Path: file.GetPath(),
Exist: file.GetExist(),
ProcessIsRunning: file.GetProcessIsRunning(),
})
}
return nbpeer.PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(),
Platform: loginReq.GetMeta().GetPlatform(),
OS: loginReq.GetMeta().GetOS(),
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: osVersion,
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
UIVersion: loginReq.GetMeta().GetUiVersion(),
KernelVersion: loginReq.GetMeta().GetKernelVersion(),
WtVersion: meta.GetWiretrusteeVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
NetworkAddresses: networkAddresses,
SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(),
SystemProductName: loginReq.GetMeta().GetSysProductName(),
SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(),
Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(),
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Files: files,
}
}
@@ -335,24 +366,11 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, msg
}
userID := ""
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
// or it uses a setup key to register.
if loginReq.GetJwtToken() != "" {
for i := 0; i < 3; i++ {
userID, err = s.validateToken(loginReq.GetJwtToken())
if err == nil {
break
}
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey, err)
time.Sleep(200 * time.Millisecond)
}
if err != nil {
return nil, err
}
userID, err := s.processJwtToken(loginReq, peerKey)
if err != nil {
return nil, err
}
var sshKey []byte
if loginReq.GetPeerKeys() != nil {
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
@@ -361,12 +379,11 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq),
Meta: extractPeerMeta(loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
})
if err != nil {
log.Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err)
@@ -381,6 +398,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(s.accountManager, peerKey.String()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
@@ -394,6 +412,31 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
}, nil
}
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
// the token is valid.
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
for i := 0; i < 3; i++ {
userID, err = s.validateToken(loginReq.GetJwtToken())
if err == nil {
break
}
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
time.Sleep(200 * time.Millisecond)
}
if err != nil {
return "", err
}
}
return userID, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
@@ -477,7 +520,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -508,6 +551,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(accountManager, peer.Key),
}
}
@@ -526,7 +570,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -643,3 +687,67 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
Body: encryptedResp,
}, nil
}
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq)
if err != nil {
return nil, err
}
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
return nil, msg
}
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(err)
}
return &proto.Empty{}, nil
}
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side.
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
}
return protoChecks
}
// toProtocolCheck converts a posture.Checks to a proto.Checks.
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks {
protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil {
for _, process := range check.Processes {
if process.LinuxPath != "" {
protoCheck.Files = append(protoCheck.Files, process.LinuxPath)
}
if process.MacPath != "" {
protoCheck.Files = append(protoCheck.Files, process.MacPath)
}
if process.WindowsPath != "" {
protoCheck.Files = append(protoCheck.Files, process.WindowsPath)
}
}
}
return protoCheck
}

View File

@@ -817,6 +817,8 @@ components:
$ref: '#/components/schemas/GeoLocationCheck'
peer_network_range_check:
$ref: '#/components/schemas/PeerNetworkRangeCheck'
process_check:
$ref: '#/components/schemas/ProcessCheck'
NBVersionCheck:
description: Posture check for the version of NetBird
type: object
@@ -905,6 +907,32 @@ components:
required:
- ranges
- action
ProcessCheck:
description: Posture Check for binaries exist and are running in the peers system
type: object
properties:
processes:
type: array
items:
$ref: '#/components/schemas/Process'
required:
- processes
Process:
description: Describes the operational activity within a peer's system.
type: object
properties:
linux_path:
description: Path to the process executable file in a Linux operating system
type: string
example: "/usr/local/bin/netbird"
mac_path:
description: Path to the process executable file in a Mac operating system
type: string
example: "/Applications/NetBird.app/Contents/MacOS/netbird"
windows_path:
description: Path to the process executable file in a Windows operating system
type: string
example: "C:\ProgramData\NetBird\netbird.exe"
Location:
description: Describe geographical location information
type: object
@@ -995,9 +1023,17 @@ components:
type: string
example: chacbco6lnnbn6cg5s91
network:
description: Network range in CIDR format
description: Network range in CIDR format, Conflicts with domains
type: string
example: 10.64.0.0/24
domains:
description: Domain list to be dynamically resolved. Conflicts with network
type: array
items:
type: string
minLength: 1
maxLength: 255
example: "example.com"
metric:
description: Route metric number. Lowest number has higher priority
type: integer
@@ -1014,6 +1050,10 @@ components:
items:
type: string
example: "chacdk86lnnboviihd70"
keep_route:
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
type: boolean
example: true
required:
- id
- description
@@ -1022,10 +1062,13 @@ components:
# Only one property has to be set
#- peer
#- peer_groups
- network
# Only one property has to be set
#- network
#- domains
- metric
- masquerade
- groups
- keep_route
Route:
allOf:
- type: object
@@ -1035,7 +1078,7 @@ components:
type: string
example: chacdk86lnnboviihd7g
network_type:
description: Network type indicating if it is IPv4 or IPv6
description: Network type indicating if it is a domain route or a IPv4/IPv6 route
type: string
example: IPv4
required:

View File

@@ -225,6 +225,9 @@ type Checks struct {
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
// ProcessCheck Posture Check for binaries exist and are running in the peers system
ProcessCheck *ProcessCheck `json:"process_check,omitempty"`
}
// City Describe city geographical location information
@@ -949,11 +952,31 @@ type PostureCheckUpdate struct {
Name string `json:"name"`
}
// Process Describes the operational activity within a peer's system.
type Process struct {
// LinuxPath Path to the process executable file in a Linux operating system
LinuxPath *string `json:"linux_path,omitempty"`
// MacPath Path to the process executable file in a Mac operating system
MacPath *string `json:"mac_path,omitempty"`
// WindowsPath Path to the process executable file in a Windows operating system
WindowsPath *string `json:"windows_path,omitempty"`
}
// ProcessCheck Posture Check for binaries exist and are running in the peers system
type ProcessCheck struct {
Processes []Process `json:"processes"`
}
// Route defines model for Route.
type Route struct {
// Description Route description
Description string `json:"description"`
// Domains Domain list to be dynamically resolved. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
Enabled bool `json:"enabled"`
@@ -963,19 +986,22 @@ type Route struct {
// Id Route Id
Id string `json:"id"`
// KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore
KeepRoute bool `json:"keep_route"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Network Network range in CIDR format
Network string `json:"network"`
// Network Network range in CIDR format, Conflicts with domains
Network *string `json:"network,omitempty"`
// NetworkId Route network identifier, to group HA routes
NetworkId string `json:"network_id"`
// NetworkType Network type indicating if it is IPv4 or IPv6
// NetworkType Network type indicating if it is a domain route or a IPv4/IPv6 route
NetworkType string `json:"network_type"`
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
@@ -990,20 +1016,26 @@ type RouteRequest struct {
// Description Route description
Description string `json:"description"`
// Domains Domain list to be dynamically resolved. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
Enabled bool `json:"enabled"`
// Groups Group IDs containing routing peers
Groups []string `json:"groups"`
// KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore
KeepRoute bool `json:"keep_route"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Network Network range in CIDR format
Network string `json:"network"`
// Network Network range in CIDR format, Conflicts with domains
Network *string `json:"network,omitempty"`
// NetworkId Route network identifier, to group HA routes
NetworkId string `json:"network_id"`

View File

@@ -2,6 +2,7 @@ package http
import (
"net/http"
"regexp"
"github.com/gorilla/mux"
@@ -13,6 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct {
accountManager server.AccountManager
@@ -73,8 +78,8 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}

View File

@@ -27,12 +27,12 @@ const (
notFoundUserID = "notFoundUserID"
existingTokenID = "existingTokenID"
notFoundTokenID = "notFoundTokenID"
domain = "hotmail.com"
testDomain = "hotmail.com"
)
var testAccount = &server.Account{
Id: existingAccountID,
Domain: domain,
Domain: testDomain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
@@ -117,7 +117,7 @@ func initPATTestData() *PATHandler {
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: domain,
Domain: testDomain,
AccountId: testNSGroupAccountID,
}
}),

View File

@@ -3,8 +3,6 @@ package http
import (
"encoding/json"
"net/http"
"regexp"
"slices"
"github.com/gorilla/mux"
@@ -17,10 +15,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
@@ -163,23 +157,20 @@ func (p *PostureChecksHandler) savePostureChecks(
user *server.User,
postureChecksID string,
) {
var (
err error
req api.PostureCheckUpdate
)
var req api.PostureCheckUpdate
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := validatePostureChecksUpdate(req)
if err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
}
@@ -197,69 +188,3 @@ func (p *PostureChecksHandler) savePostureChecks(
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
if req.Name == "" {
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
}
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
}
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return status.Errorf(status.InvalidArgument,
"minimum version for at least one OS in the OS version check shouldn't be empty")
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if geoLocationCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
}
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
}
if len(geoLocationCheck.Locations) == 0 {
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
}
for _, loc := range geoLocationCheck.Locations {
if loc.CountryCode == "" {
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
}
}
}
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
}
if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
}
}
return nil
}

View File

@@ -43,6 +43,11 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
}
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
@@ -433,6 +438,45 @@ func TestPostureCheckUpdate(t *testing.T) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks Process Check",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"process_check": {
"processes": [
{
"linux_path": "/usr/local/bin/netbird",
"mac_path": "/Applications/NetBird.app/Contents/MacOS/netbird",
"windows_path": "C:\\ProgramData\\NetBird\\netbird.exe"
}
]
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
ProcessCheck: &api.ProcessCheck{
Processes: []api.Process{
{
LinuxPath: str("/usr/local/bin/netbird"),
MacPath: str("/Applications/NetBird.app/Contents/MacOS/netbird"),
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
},
},
},
},
},
},
{
name: "Create Posture Checks Invalid Check",
requestType: http.MethodPost,
@@ -446,7 +490,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -461,7 +505,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -475,7 +519,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -489,7 +533,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"geo_location_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -663,11 +707,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
@@ -682,7 +723,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -697,7 +738,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -711,7 +752,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -841,100 +882,3 @@ func TestPostureCheckUpdate(t *testing.T) {
})
}
}
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
// empty name
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
assert.Error(t, err)
// empty checks
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
assert.Error(t, err)
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
assert.Error(t, err)
// not valid NbVersionCheck
nbVersionCheck := api.NBVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.Error(t, err)
// valid NbVersionCheck
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.NoError(t, err)
// not valid OsVersionCheck
osVersionCheck := api.OSVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid peer network range check
peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8",
},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
}

View File

@@ -2,11 +2,16 @@ package http
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"regexp"
"strings"
"unicode/utf8"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
@@ -15,6 +20,9 @@ import (
"github.com/netbirdio/netbird/route"
)
const maxDomains = 32
const failedToConvertRoute = "failed to convert route to response: %v"
// RoutesHandler is the routes handler of the account
type RoutesHandler struct {
accountManager server.AccountManager
@@ -48,7 +56,12 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
}
apiRoutes := make([]*api.Route, 0)
for _, r := range routes {
apiRoutes = append(apiRoutes, toRouteResponse(r))
route, err := toRouteResponse(r)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
apiRoutes = append(apiRoutes, route)
}
util.WriteJSONObject(w, apiRoutes)
@@ -70,16 +83,28 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
return
}
_, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil {
if err := h.validateRoute(req); err != nil {
util.WriteError(err, w)
return
}
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
route.MaxNetIDChar), w)
return
var domains domain.List
var networkType route.NetworkType
var newPrefix netip.Prefix
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
domains = d
networkType = route.DomainNetwork
} else if req.Network != nil {
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(err, w)
return
}
}
peerId := ""
@@ -87,36 +112,57 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerId = *req.Peer
}
peerGroupIds := []string{}
var peerGroupIds []string
if req.PeerGroups != nil {
peerGroupIds = *req.PeerGroups
}
if (peerId != "" && len(peerGroupIds) > 0) || (peerId == "" && len(peerGroupIds) == 0) {
util.WriteError(status.Errorf(status.InvalidArgument, "only one peer or peer_groups should be provided"), w)
return
}
// do not allow non Linux peers
// Do not allow non-Linux peers
if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(
account.Id, newPrefix.String(), peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
)
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRouteResponse(newRoute)
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(w, routes)
}
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
if req.Network != nil && req.Domains != nil {
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
}
if req.Network == nil && req.Domains == nil {
return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided")
}
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
}
if req.Peer != nil && req.PeerGroups != nil {
return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided")
}
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters",
route.MaxNetIDChar)
}
return nil
}
// UpdateRoute handles update to a route identified by a given ID
@@ -148,26 +194,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
req.Network, routeID), w)
return
}
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
util.WriteError(status.Errorf(status.InvalidArgument,
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
return
}
if req.Peer != nil && req.PeerGroups != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "only peer or peers_group should be provided"), w)
return
}
if req.Peer == nil && req.PeerGroups == nil {
util.WriteError(status.Errorf(status.InvalidArgument, "either peer or peers_group should be provided"), w)
if err := h.validateRoute(req); err != nil {
util.WriteError(err, w)
return
}
@@ -186,14 +214,29 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute := &route.Route{
ID: route.ID(routeID),
Network: newPrefix,
NetID: route.NetID(req.NetworkId),
NetworkType: prefixType,
Masquerade: req.Masquerade,
Metric: req.Metric,
Description: req.Description,
Enabled: req.Enabled,
Groups: req.Groups,
KeepRoute: req.KeepRoute,
}
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
newRoute.Domains = d
newRoute.NetworkType = route.DomainNetwork
} else if req.Network != nil {
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(err, w)
return
}
}
if req.Peer != nil {
@@ -210,9 +253,13 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
resp := toRouteResponse(newRoute)
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(w, routes)
}
// DeleteRoute handles route deletion request
@@ -260,25 +307,69 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
util.WriteJSONObject(w, toRouteResponse(foundRoute))
routes, err := toRouteResponse(foundRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, routes)
}
func toRouteResponse(serverRoute *route.Route) *api.Route {
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
domains, err := serverRoute.Domains.ToStringList()
if err != nil {
return nil, err
}
network := serverRoute.Network.String()
route := &api.Route{
Id: string(serverRoute.ID),
Description: serverRoute.Description,
NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer,
Network: serverRoute.Network.String(),
Network: &network,
Domains: &domains,
NetworkType: serverRoute.NetworkType.String(),
Masquerade: serverRoute.Masquerade,
Metric: serverRoute.Metric,
Groups: serverRoute.Groups,
KeepRoute: serverRoute.KeepRoute,
}
if len(serverRoute.PeerGroups) > 0 {
route.PeerGroups = &serverRoute.PeerGroups
}
return route
return route, nil
}
// validateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func validateDomains(domains []string) (domain.List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var domainList domain.List
for _, d := range domains {
d := strings.ToLower(d)
// handles length and idna conversion
punycode, err := domain.FromString(d)
if err != nil {
return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err)
}
if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
}
domainList = append(domainList, punycode)
}
return domainList, nil
}

View File

@@ -10,6 +10,8 @@ import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@@ -18,6 +20,7 @@ import (
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -26,6 +29,7 @@ import (
const (
existingRouteID = "existingRouteID"
existingRouteID2 = "existingRouteID2" // for peer_groups test
existingRouteID3 = "existingRouteID3" // for domains test
notFoundRouteID = "notFoundRouteID"
existingPeerIP1 = "100.64.0.100"
existingPeerIP2 = "100.64.0.101"
@@ -35,6 +39,7 @@ const (
testAccountID = "test_id"
existingGroupID = "testGroup"
notFoundGroupID = "nonExistingGroup"
existingDomain = "example.com"
)
var emptyString = ""
@@ -46,6 +51,8 @@ var baseExistingRoute = &route.Route{
Description: "base route",
NetID: "awesomeNet",
Network: netip.MustParsePrefix("192.168.0.0/24"),
Domains: domain.List{},
KeepRoute: false,
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -90,28 +97,33 @@ func initRoutesTestData() *RoutesHandler {
route := baseExistingRoute.Copy()
route.PeerGroups = []string{existingGroupID}
return route, nil
} else if routeID == existingRouteID3 {
route := baseExistingRoute.Copy()
route.Domains = domain.List{existingDomain}
return route, nil
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
networkType, p, _ := route.ParseNetwork(network)
return &route.Route{
ID: existingRouteID,
NetID: netID,
Peer: peerID,
PeerGroups: peerGroups,
Network: p,
Network: prefix,
Domains: domains,
NetworkType: networkType,
Description: description,
Masquerade: masquerade,
Enabled: enabled,
Groups: groups,
KeepRoute: keepRoute,
}, nil
},
SaveRouteFunc: func(_, _ string, r *route.Route) error {
@@ -146,6 +158,9 @@ func TestRoutesHandlers(t *testing.T) {
baseExistingRouteWithPeerGroups := baseExistingRoute.Copy()
baseExistingRouteWithPeerGroups.PeerGroups = []string{existingGroupID}
baseExistingRouteWithDomains := baseExistingRoute.Copy()
baseExistingRouteWithDomains.Domains = domain.List{existingDomain}
tt := []struct {
name string
expectedStatus int
@@ -161,7 +176,7 @@ func TestRoutesHandlers(t *testing.T) {
requestPath: "/api/routes/" + existingRouteID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: toRouteResponse(baseExistingRoute),
expectedRoute: toApiRoute(t, baseExistingRoute),
},
{
name: "Get Not Existing Route",
@@ -175,7 +190,15 @@ func TestRoutesHandlers(t *testing.T) {
requestPath: "/api/routes/" + existingRouteID2,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: toRouteResponse(baseExistingRouteWithPeerGroups),
expectedRoute: toApiRoute(t, baseExistingRouteWithPeerGroups),
},
{
name: "Get Existing Route with Domains",
requestType: http.MethodGet,
requestPath: "/api/routes/" + existingRouteID3,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: toApiRoute(t, baseExistingRouteWithDomains),
},
{
name: "Delete Existing Route",
@@ -191,18 +214,18 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
name: "Network POST OK",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID))),
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: "192.168.0.0/16",
Network: toPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
@@ -210,6 +233,28 @@ func TestRoutesHandlers(t *testing.T) {
Groups: []string{existingGroupID},
},
},
{
name: "Domains POST OK",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "domainNet",
Network: toPtr("invalid Prefix"),
KeepRoute: true,
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{
name: "POST Non Linux Peer",
requestType: http.MethodPost,
@@ -242,6 +287,32 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "POST Invalid Domains",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "POST UnprocessableEntity when both network and domains are provided",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "POST UnprocessableEntity when no network and domains are provided",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","groups":["%s"]}`, existingPeerID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "POST UnprocessableEntity when both peer and peer_groups are provided",
requestType: http.MethodPost,
@@ -261,7 +332,7 @@ func TestRoutesHandlers(t *testing.T) {
expectedBody: false,
},
{
name: "PUT OK",
name: "Network PUT OK",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)),
@@ -271,7 +342,7 @@ func TestRoutesHandlers(t *testing.T) {
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: "192.168.0.0/16",
Network: toPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
@@ -279,6 +350,27 @@ func TestRoutesHandlers(t *testing.T) {
Groups: []string{existingGroupID},
},
},
{
name: "Domains PUT OK",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: toPtr("invalid Prefix"),
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
KeepRoute: true,
},
},
{
name: "PUT OK when peer_groups provided",
requestType: http.MethodPut,
@@ -290,7 +382,7 @@ func TestRoutesHandlers(t *testing.T) {
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: "192.168.0.0/16",
Network: toPtr("192.168.0.0/16"),
Peer: &emptyString,
PeerGroups: &[]string{existingGroupID},
NetworkType: route.IPv4NetworkString,
@@ -339,6 +431,33 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PUT Invalid Domains",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PUT UnprocessableEntity when both network and domains are provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PUT UnprocessableEntity when no network and domains are provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PUT UnprocessableEntity when both peer and peer_groups are provided",
requestType: http.MethodPut,
@@ -399,3 +518,85 @@ func TestRoutesHandlers(t *testing.T) {
})
}
}
func TestValidateDomains(t *testing.T) {
tests := []struct {
name string
domains []string
expected domain.List
wantErr bool
}{
{
name: "Empty list",
domains: nil,
expected: nil,
wantErr: true,
},
{
name: "Valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
expected: domain.List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
expected: domain.List{"xn--mnchen-3ya.de"},
wantErr: false,
},
{
name: "Valid Unicode, all labels",
domains: []string{"中国.中国.中国"},
expected: domain.List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
wantErr: false,
},
{
name: "With underscores",
domains: []string{"_jabber._tcp.gmail.com"},
expected: domain.List{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
name: "Invalid domain format",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
expected: domain.List{"google.com"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := validateDomains(tt.domains)
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, got, tt.expected)
})
}
}
func toApiRoute(t *testing.T, r *route.Route) *api.Route {
t.Helper()
apiRoute, err := toRouteResponse(r)
// json flattens pointer to nil slices to null
if apiRoute.Domains != nil && *apiRoute.Domains == nil {
apiRoute.Domains = nil
}
require.NoError(t, err, "Failed to convert route")
return apiRoute
}
func toPtr[T any](v T) *T {
return &v
}

View File

@@ -27,7 +27,7 @@ const (
var usersTestAccount = &server.Account{
Id: existingAccountID,
Domain: domain,
Domain: testDomain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
@@ -127,7 +127,7 @@ func initUsersTestData() *UsersHandler {
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: domain,
Domain: testDomain,
AccountId: existingAccountID,
}
}),

View File

@@ -134,7 +134,8 @@ func Test_SyncProtocol(t *testing.T) {
// take the first registered peer as a base for the test. Total four.
key := *peers[0]
message, err := encryption.EncryptMessage(*serverKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, key, syncReq)
if err != nil {
t.Fatal(err)
return

View File

@@ -93,7 +93,8 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
@@ -143,7 +144,7 @@ var _ = Describe("Management service", func() {
loginPeerWithValidSetupKey(serverPubKey, key1, client)
loginPeerWithValidSetupKey(serverPubKey, key2, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -176,7 +177,7 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -329,7 +330,7 @@ var _ = Describe("Management service", func() {
var clients []mgmtProto.ManagementService_SyncClient
for _, peer := range peers {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
Expect(err).NotTo(HaveOccurred())
@@ -394,7 +395,8 @@ var _ = Describe("Management service", func() {
defer GinkgoRecover()
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
// open stream

View File

@@ -2,12 +2,14 @@ package mock_server
import (
"net"
"net/netip"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
@@ -28,7 +30,7 @@ type MockAccountManager struct {
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
@@ -52,7 +54,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
CreateRouteFunc func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
@@ -81,6 +83,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error)
@@ -95,12 +98,13 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -413,9 +417,9 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
func (am *MockAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
return am.CreateRouteFunc(accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
@@ -623,6 +627,14 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
}
// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface
func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
if am.GetPeerAppliedPostureChecksFunc != nil {
return am.GetPeerAppliedPostureChecksFunc(peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented")
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
if am.UpdateAccountSettingsFunc != nil {
@@ -736,6 +748,14 @@ func (am *MockAccountManager) GroupValidation(accountId string, groups []string)
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(peerPubKey, meta)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}
// FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface
func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
if am.FindExistingPostureCheckFunc != nil {

View File

@@ -3,9 +3,10 @@ package mock_server
import (
"context"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/proto"
)
type ManagementServiceServerMock struct {
@@ -17,6 +18,7 @@ type ManagementServiceServerMock struct {
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
GetPKCEAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
SyncMetaFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error)
}
func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
@@ -60,3 +62,10 @@ func (m ManagementServiceServerMock) GetPKCEAuthorizationFlow(ctx context.Contex
}
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (m ManagementServiceServerMock) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
if m.SyncMetaFunc != nil {
return m.SyncMetaFunc(ctx, req)
}
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}

View File

@@ -19,6 +19,11 @@ import (
type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool
}
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
@@ -528,6 +533,18 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(account)
}
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, err
@@ -900,7 +917,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
continue
}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
)
@@ -79,6 +80,13 @@ type Environment struct {
Platform string
}
// File is a file on the system.
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct { //nolint:revive
Hostname string
@@ -96,24 +104,22 @@ type PeerSystemMeta struct { //nolint:revive
SystemProductName string
SystemManufacturer string
Environment Environment `gorm:"serializer:json"`
Files []File `gorm:"serializer:json"`
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
if len(p.NetworkAddresses) != len(other.NetworkAddresses) {
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
for _, addr := range p.NetworkAddresses {
var found bool
for _, oAddr := range other.NetworkAddresses {
if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP {
found = true
continue
}
}
if !found {
return false
}
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
@@ -133,6 +139,26 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
p.Environment.Platform == other.Environment.Platform
}
func (p PeerSystemMeta) isEmpty() bool {
return p.Hostname == "" &&
p.GoOS == "" &&
p.Kernel == "" &&
p.Core == "" &&
p.Platform == "" &&
p.OS == "" &&
p.OSVersion == "" &&
p.WtVersion == "" &&
p.UIVersion == "" &&
p.KernelVersion == "" &&
len(p.NetworkAddresses) == 0 &&
p.SystemSerialNumber == "" &&
p.SystemProductName == "" &&
p.SystemManufacturer == "" &&
p.Environment.Cloud == "" &&
p.Environment.Platform == "" &&
len(p.Files) == 0
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
@@ -168,6 +194,10 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
if meta.isEmpty() {
return false
}
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion

View File

@@ -1,8 +1,9 @@
package posture
import (
"fmt"
"errors"
"net/netip"
"regexp"
"github.com/hashicorp/go-version"
"github.com/rs/xid"
@@ -17,15 +18,21 @@ const (
OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck"
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
ProcessCheckName = "ProcessCheck"
CheckActionAllow string = "allow"
CheckActionDeny string = "deny"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// Check represents an interface for performing a check on a peer.
type Check interface {
Check(peer nbpeer.Peer) (bool, error)
Name() string
Check(peer nbpeer.Peer) (bool, error)
Validate() error
}
type Checks struct {
@@ -51,6 +58,7 @@ type ChecksDefinition struct {
OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
ProcessCheck *ProcessCheck `json:",omitempty"`
}
// Copy returns a copy of a checks definition.
@@ -96,6 +104,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
}
copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
}
if cd.ProcessCheck != nil {
processCheck := cd.ProcessCheck
cdCopy.ProcessCheck = &ProcessCheck{
Processes: make([]Process, len(processCheck.Processes)),
}
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
}
return cdCopy
}
@@ -136,6 +151,9 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.PeerNetworkRangeCheck != nil {
checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
}
if pc.Checks.ProcessCheck != nil {
checks = append(checks, pc.Checks.ProcessCheck)
}
return checks
}
@@ -191,6 +209,10 @@ func buildPostureCheck(postureChecksID string, name string, description string,
}
}
if processCheck := checks.ProcessCheck; processCheck != nil {
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
}
return &postureChecks, nil
}
@@ -221,6 +243,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(pc.Checks.PeerNetworkRangeCheck)
}
if pc.Checks.ProcessCheck != nil {
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
}
return &api.PostureCheck{
Id: pc.ID,
Name: pc.Name,
@@ -229,44 +255,20 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
}
}
// Validate checks the validity of a posture checks.
func (pc *Checks) Validate() error {
if check := pc.Checks.NBVersionCheck; check != nil {
if !isVersionValid(check.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", check.Name(), check.MinVersion)
}
if pc.Name == "" {
return errors.New("posture checks name shouldn't be empty")
}
if osCheck := pc.Checks.OSVersionCheck; osCheck != nil {
if osCheck.Android != nil {
if !isVersionValid(osCheck.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", osCheck.Name(), osCheck.Android.MinVersion)
}
}
checks := pc.GetChecks()
if len(checks) == 0 {
return errors.New("posture checks shouldn't be empty")
}
if osCheck.Ios != nil {
if !isVersionValid(osCheck.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", osCheck.Name(), osCheck.Ios.MinVersion)
}
}
if osCheck.Darwin != nil {
if !isVersionValid(osCheck.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", osCheck.Name(), osCheck.Darwin.MinVersion)
}
}
if osCheck.Linux != nil {
if !isVersionValid(osCheck.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", osCheck.Name(),
osCheck.Linux.MinKernelVersion)
}
}
if osCheck.Windows != nil {
if !isVersionValid(osCheck.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", osCheck.Name(),
osCheck.Windows.MinKernelVersion)
}
for _, check := range checks {
if err := check.Validate(); err != nil {
return err
}
}
@@ -352,3 +354,40 @@ func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*PeerNetworkRang
Action: string(check.Action),
}, nil
}
func toProcessCheckResponse(check *ProcessCheck) *api.ProcessCheck {
processes := make([]api.Process, 0, len(check.Processes))
for i := range check.Processes {
processes = append(processes, api.Process{
LinuxPath: &check.Processes[i].LinuxPath,
MacPath: &check.Processes[i].MacPath,
WindowsPath: &check.Processes[i].WindowsPath,
})
}
return &api.ProcessCheck{
Processes: processes,
}
}
func toProcessCheck(check *api.ProcessCheck) *ProcessCheck {
processes := make([]Process, 0, len(check.Processes))
for _, process := range check.Processes {
var p Process
if process.LinuxPath != nil {
p.LinuxPath = *process.LinuxPath
}
if process.MacPath != nil {
p.MacPath = *process.MacPath
}
if process.WindowsPath != nil {
p.WindowsPath = *process.WindowsPath
}
processes = append(processes, p)
}
return &ProcessCheck{
Processes: processes,
}
}

View File

@@ -150,9 +150,23 @@ func TestChecks_Validate(t *testing.T) {
checks Checks
expectedError bool
}{
{
name: "Empty name",
checks: Checks{},
expectedError: true,
},
{
name: "Empty checks",
checks: Checks{
Name: "Default",
Checks: ChecksDefinition{},
},
expectedError: true,
},
{
name: "Valid checks version",
checks: Checks{
Name: "default",
Checks: ChecksDefinition{
NBVersionCheck: &NBVersionCheck{
MinVersion: "0.25.0",
@@ -261,6 +275,14 @@ func TestChecks_Copy(t *testing.T) {
},
Action: CheckActionDeny,
},
ProcessCheck: &ProcessCheck{
Processes: []Process{
{
MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
},
}
checkCopy := check.Copy()

View File

@@ -2,6 +2,7 @@ package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -60,3 +61,28 @@ func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
func (g *GeoLocationCheck) Name() string {
return GeoLocationCheckName
}
func (g *GeoLocationCheck) Validate() error {
if g.Action == "" {
return fmt.Errorf("%s action shouldn't be empty", g.Name())
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, g.Action) {
return fmt.Errorf("%s action is not valid", g.Name())
}
if len(g.Locations) == 0 {
return fmt.Errorf("%s locations shouldn't be empty", g.Name())
}
for _, loc := range g.Locations {
if loc.CountryCode == "" {
return fmt.Errorf("%s country code shouldn't be empty", g.Name())
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return fmt.Errorf("%s country code must be 2 letters (ISO 3166-1 alpha-2 format)", g.Name())
}
}
return nil
}

View File

@@ -236,3 +236,81 @@ func TestGeoLocationCheck_Check(t *testing.T) {
})
}
}
func TestGeoLocationCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check GeoLocationCheck
expectedError bool
}{
{
name: "Valid location list",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: false,
},
{
name: "Invalid empty location list",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{},
},
expectedError: true,
},
{
name: "Invalid empty country name",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{
{
CityName: "Los Angeles",
},
},
},
expectedError: true,
},
{
name: "Invalid check action",
check: GeoLocationCheck{
Action: "unknownAction",
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: true,
},
{
name: "Invalid country code",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "USA",
},
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,6 +1,8 @@
package posture
import (
"fmt"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -37,3 +39,13 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
func (n *NBVersionCheck) Name() string {
return NBVersionCheckName
}
func (n *NBVersionCheck) Validate() error {
if n.MinVersion == "" {
return fmt.Errorf("%s minimum version shouldn't be empty", n.Name())
}
if !isVersionValid(n.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", n.Name(), n.MinVersion)
}
return nil
}

View File

@@ -108,3 +108,33 @@ func TestNBVersionCheck_Check(t *testing.T) {
})
}
}
func TestNBVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check NBVersionCheck
expectedError bool
}{
{
name: "Valid NBVersionCheck",
check: NBVersionCheck{MinVersion: "1.0"},
expectedError: false,
},
{
name: "Invalid NBVersionCheck",
check: NBVersionCheck{},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type PeerNetworkRangeCheck struct {
@@ -52,3 +53,19 @@ func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
func (p *PeerNetworkRangeCheck) Name() string {
return PeerNetworkRangeCheckName
}
func (p *PeerNetworkRangeCheck) Validate() error {
if p.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, p.Action) {
return fmt.Errorf("%s action is not valid", p.Name())
}
if len(p.Ranges) == 0 {
return fmt.Errorf("%s network ranges shouldn't be empty", p.Name())
}
return nil
}

View File

@@ -147,3 +147,52 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
})
}
}
func TestNetworkCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check PeerNetworkRangeCheck
expectedError bool
}{
{
name: "Valid network range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: false,
},
{
name: "Invalid empty network range",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{},
},
expectedError: true,
},
{
name: "Invalid check action",
check: PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,11 +1,13 @@
package posture
import (
"fmt"
"strings"
"github.com/hashicorp/go-version"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type MinVersionCheck struct {
@@ -48,6 +50,35 @@ func (c *OSVersionCheck) Name() string {
return OSVersionCheckName
}
func (c *OSVersionCheck) Validate() error {
if c.Android == nil && c.Darwin == nil && c.Ios == nil && c.Linux == nil && c.Windows == nil {
return fmt.Errorf("%s at least one OS version check is required", c.Name())
}
if c.Android != nil && !isVersionValid(c.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", c.Name(), c.Android.MinVersion)
}
if c.Ios != nil && !isVersionValid(c.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", c.Name(), c.Ios.MinVersion)
}
if c.Darwin != nil && !isVersionValid(c.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", c.Name(), c.Darwin.MinVersion)
}
if c.Linux != nil && !isVersionValid(c.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", c.Name(),
c.Linux.MinKernelVersion)
}
if c.Windows != nil && !isVersionValid(c.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", c.Name(),
c.Windows.MinKernelVersion)
}
return nil
}
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)

View File

@@ -150,3 +150,79 @@ func TestOSVersionCheck_Check(t *testing.T) {
})
}
}
func TestOSVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check OSVersionCheck
expectedError bool
}{
{
name: "Valid linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
},
expectedError: false,
},
{
name: "Valid linux and darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: false,
},
{
name: "Invalid empty check",
check: OSVersionCheck{},
expectedError: true,
},
{
name: "Invalid empty linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
},
expectedError: true,
},
{
name: "Invalid empty linux kernel version with correct darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: true,
},
{
name: "Valid windows kernel version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{MinKernelVersion: "10.0"},
},
expectedError: false,
},
{
name: "Valid ios minimum version",
check: OSVersionCheck{
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: false,
},
{
name: "Invalid empty window version with valid ios minimum version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{},
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,79 @@
package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type Process struct {
LinuxPath string
MacPath string
WindowsPath string
}
type ProcessCheck struct {
Processes []Process
}
var _ Check = (*ProcessCheck)(nil)
func (p *ProcessCheck) Check(peer nbpeer.Peer) (bool, error) {
peerActiveProcesses := extractPeerActiveProcesses(peer.Meta.Files)
var pathSelector func(Process) string
switch peer.Meta.GoOS {
case "linux":
pathSelector = func(process Process) string { return process.LinuxPath }
case "darwin":
pathSelector = func(process Process) string { return process.MacPath }
case "windows":
pathSelector = func(process Process) string { return process.WindowsPath }
default:
return false, fmt.Errorf("unsupported peer's operating system: %s", peer.Meta.GoOS)
}
return p.areAllProcessesRunning(peerActiveProcesses, pathSelector), nil
}
func (p *ProcessCheck) Name() string {
return ProcessCheckName
}
func (p *ProcessCheck) Validate() error {
if len(p.Processes) == 0 {
return fmt.Errorf("%s processes shouldn't be empty", p.Name())
}
for _, process := range p.Processes {
if process.LinuxPath == "" && process.MacPath == "" && process.WindowsPath == "" {
return fmt.Errorf("%s path shouldn't be empty", p.Name())
}
}
return nil
}
// areAllProcessesRunning checks if all processes specified in ProcessCheck are running.
// It uses the provided pathSelector to get the appropriate process path for the peer's OS.
// It returns true if all processes are running, otherwise false.
func (p *ProcessCheck) areAllProcessesRunning(activeProcesses []string, pathSelector func(Process) string) bool {
for _, process := range p.Processes {
path := pathSelector(process)
if path == "" || !slices.Contains(activeProcesses, path) {
return false
}
}
return true
}
// extractPeerActiveProcesses extracts the paths of running processes from the peer meta.
func extractPeerActiveProcesses(files []nbpeer.File) []string {
activeProcesses := make([]string, 0, len(files))
for _, file := range files {
if file.ProcessIsRunning {
activeProcesses = append(activeProcesses, file.Path)
}
}
return activeProcesses
}

View File

@@ -0,0 +1,318 @@
package posture
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/peer"
)
func TestProcessCheck_Check(t *testing.T) {
tests := []struct {
name string
input peer.Peer
check ProcessCheck
wantErr bool
isValid bool
}{
{
name: "darwin with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{MacPath: "/Applications/process1.app"},
{MacPath: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: true,
},
{
name: "darwin with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{LinuxPath: "/usr/bin/process1"},
{LinuxPath: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: true,
},
{
name: "linux with matching no running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: false},
},
},
},
check: ProcessCheck{
Processes: []Process{
{LinuxPath: "/usr/bin/process1"},
{LinuxPath: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process3"},
{Path: "/usr/bin/process4"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{LinuxPath: "/usr/bin/process1"},
{LinuxPath: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process1.exe"},
},
},
wantErr: false,
isValid: true,
},
{
name: "windows with darwin process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe"},
{Path: "C:\\Program Files\\process1.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{MacPath: "/Applications/process1.app"},
{LinuxPath: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process3.exe"},
{Path: "C:\\Program Files\\process4.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "unsupported ios operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "ios",
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{MacPath: "/Applications/process2.app"},
},
},
wantErr: true,
isValid: false,
},
{
name: "unsupported android operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "android",
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{MacPath: "/Applications/process2.app"},
{LinuxPath: "/usr/bin/process2"},
},
},
wantErr: true,
isValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.isValid, isValid)
})
}
}
func TestProcessCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check ProcessCheck
expectedError bool
}{
{
name: "Valid linux, mac and windows processes",
check: ProcessCheck{
Processes: []Process{
{
LinuxPath: "/usr/local/bin/netbird",
MacPath: "/usr/local/bin/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Valid linux process",
check: ProcessCheck{
Processes: []Process{
{
LinuxPath: "/usr/local/bin/netbird",
},
},
},
expectedError: false,
},
{
name: "Valid mac process",
check: ProcessCheck{
Processes: []Process{
{
MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird",
},
},
},
expectedError: false,
},
{
name: "Valid windows process",
check: ProcessCheck{
Processes: []Process{
{
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Invalid empty processes",
check: ProcessCheck{
Processes: []Process{},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,9 +1,17 @@
package server
import (
"slices"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
)
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -21,7 +29,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
for _, postureChecks := range account.PostureChecks {
@@ -48,11 +56,11 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.BadRequest, err.Error())
return status.Errorf(status.InvalidArgument, err.Error())
}
exists, uniqName := am.savePostureChecks(account, postureChecks)
@@ -95,7 +103,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
@@ -127,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return account.PostureChecks, nil
@@ -176,3 +184,74 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
return postureChecks, nil
}
// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer.
func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerKey)
if err != nil {
log.Errorf("failed while getting peer %s: %v", peerKey, err)
return nil, err
}
peer, err := account.FindPeerByPubKey(peerKey)
if err != nil {
return nil, status.Errorf(status.NotFound, "peer is not registered")
}
if peer == nil {
return nil, nil
}
peerPostureChecks := am.collectPeerPostureChecks(account, peer)
postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
postureChecksList = append(postureChecksList, check)
}
return postureChecksList, nil
}
// collectPeerPostureChecks collects the posture checks applied for a given peer.
func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer *nbpeer.Peer) map[string]posture.Checks {
peerPostureChecks := make(map[string]posture.Checks)
for _, policy := range account.Policies {
if !policy.Enabled {
continue
}
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
addPolicyPostureChecks(account, policy, peerPostureChecks)
}
}
return peerPostureChecks
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if ok && slices.Contains(group.Peers, peerID) {
return true
}
}
}
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
@@ -39,10 +40,10 @@ func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, us
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefix(prefix)
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
@@ -114,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
}
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
func (am *DefaultAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
@@ -123,6 +124,18 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
return nil, err
}
if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
if len(domains) == 0 && !prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
}
if len(domains) > 0 {
prefix = getPlaceholderIP()
}
if peerID != "" && len(peerGroupIDs) != 0 {
return nil, status.Errorf(
status.InvalidArgument,
@@ -133,11 +146,6 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
var newRoute route.Route
newRoute.ID = route.ID(xid.New().String())
prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
}
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil {
@@ -145,7 +153,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
}
}
err = am.checkRoutePrefixExistsForPeers(account, peerID, newRoute.ID, peerGroupIDs, newPrefix)
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil {
return nil, err
}
@@ -165,14 +173,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = newPrefix
newRoute.NetworkType = prefixType
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
@@ -201,10 +211,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return status.Errorf(status.InvalidArgument, "route provided is nil")
}
if !routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
}
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
@@ -218,6 +224,18 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
if len(routeToSave.Domains) == 0 && !routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "invalid Prefix")
}
if len(routeToSave.Domains) > 0 {
routeToSave.Network = getPlaceholderIP()
}
if routeToSave.Peer != "" && len(routeToSave.PeerGroups) != 0 {
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
}
@@ -229,7 +247,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
}
}
err = am.checkRoutePrefixExistsForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network)
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
}
@@ -313,10 +331,12 @@ func toProtocolRoute(route *route.Route) *proto.Route {
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
}
}
@@ -327,3 +347,9 @@ func toProtocolRoutes(routes []*route.Route) []*proto.Route {
}
return protoRoutes
}
// getPlaceholderIP returns a placeholder IP address for the route if domains are used
func getPlaceholderIP() netip.Prefix {
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -33,13 +34,18 @@ const (
routeGroupHA2 = "routeGroupHA2"
routeInvalidGroup1 = "routeInvalidGroup1"
userID = "testingUser"
existingNetwork = "10.10.10.0/24"
existingRouteID = "random-id"
)
var existingNetwork = netip.MustParsePrefix("10.10.10.0/24")
var existingDomains = domain.List{"example.com"}
func TestCreateRoute(t *testing.T) {
type input struct {
network string
network netip.Prefix
domains domain.List
keepRoute bool
networkType route.NetworkType
netID route.NetID
peerKey string
peerGroupIDs []string
@@ -59,9 +65,10 @@ func TestCreateRoute(t *testing.T) {
expectedRoute *route.Route
}{
{
name: "Happy Path",
name: "Happy Path Network",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerKey: peer1ID,
description: "super",
@@ -84,10 +91,41 @@ func TestCreateRoute(t *testing.T) {
Groups: []string{routeGroup1},
},
},
{
name: "Happy Path Domains",
inputArgs: input{
domains: domain.List{"domain1", "domain2"},
keepRoute: true,
networkType: route.DomainNetwork,
netID: "happy",
peerKey: peer1ID,
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
Network: netip.MustParsePrefix("192.0.2.0/32"),
Domains: domain.List{"domain1", "domain2"},
NetworkType: route.DomainNetwork,
NetID: "happy",
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
KeepRoute: true,
},
},
{
name: "Happy Path Peer Groups",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
description: "super",
@@ -111,9 +149,10 @@ func TestCreateRoute(t *testing.T) {
},
},
{
name: "Both peer and peer_groups Provided Should Fail",
name: "Both network and domains provided should fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
domains: domain.List{"domain1", "domain2"},
netID: "happy",
peerKey: peer1ID,
peerGroupIDs: []string{routeGroupHA1},
@@ -127,16 +166,18 @@ func TestCreateRoute(t *testing.T) {
shouldCreate: false,
},
{
name: "Bad Prefix Should Fail",
name: "Both peer and peer_groups Provided Should Fail",
inputArgs: input{
network: "192.168.0.0/34",
netID: "happy",
peerKey: peer1ID,
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerKey: peer1ID,
peerGroupIDs: []string{routeGroupHA1},
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
errFunc: require.Error,
shouldCreate: false,
@@ -144,7 +185,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Bad Peer Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerKey: "notExistingPeer",
description: "super",
@@ -157,9 +199,10 @@ func TestCreateRoute(t *testing.T) {
shouldCreate: false,
},
{
name: "Bad Peer already has this route",
name: "Bad Peer already has this network route",
inputArgs: input{
network: existingNetwork,
networkType: route.IPv4Network,
netID: "bad",
peerKey: peer5ID,
description: "super",
@@ -173,9 +216,44 @@ func TestCreateRoute(t *testing.T) {
shouldCreate: false,
},
{
name: "Bad Peers Group already has this route",
name: "Bad Peer already has this domains route",
inputArgs: input{
domains: existingDomains,
networkType: route.DomainNetwork,
netID: "bad",
peerKey: peer5ID,
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
createInitRoute: true,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Bad Peers Group already has this network route",
inputArgs: input{
network: existingNetwork,
networkType: route.IPv4Network,
netID: "bad",
peerGroupIDs: []string{routeGroup1, routeGroup3},
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
createInitRoute: true,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Bad Peers Group already has this domains route",
inputArgs: input{
domains: existingDomains,
networkType: route.DomainNetwork,
netID: "bad",
peerGroupIDs: []string{routeGroup1, routeGroup3},
description: "super",
@@ -191,7 +269,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Empty Peer Should Create",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerKey: "",
description: "super",
@@ -217,7 +296,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Large Metric Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
peerKey: peer1ID,
netID: "happy",
description: "super",
@@ -232,7 +312,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Small Metric Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "happy",
peerKey: peer1ID,
description: "super",
@@ -247,7 +328,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Large NetID Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
peerKey: peer1ID,
netID: "12345678901234567890qwertyuiopqwertyuiop1",
description: "super",
@@ -262,7 +344,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Small NetID Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "",
peerKey: peer1ID,
description: "",
@@ -277,7 +360,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Empty Group List Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "NewId",
peerKey: peer1ID,
description: "",
@@ -292,7 +376,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Empty Group ID string Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "NewId",
peerKey: peer1ID,
description: "",
@@ -307,7 +392,8 @@ func TestCreateRoute(t *testing.T) {
{
name: "Invalid Group Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
network: netip.MustParsePrefix("192.168.0.0/16"),
networkType: route.IPv4Network,
netID: "NewId",
peerKey: peer1ID,
description: "",
@@ -334,29 +420,14 @@ func TestCreateRoute(t *testing.T) {
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
if errInit != nil {
t.Errorf("failed to get group all: %s", errInit)
}
_, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4},
"", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID)
if errInit != nil {
t.Errorf("failed to create init route: %s", errInit)
}
require.NoError(t, errInit)
_, errInit = am.CreateRoute(account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
_, errInit = am.CreateRoute(account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}
outRoute, err := am.CreateRoute(
account.Id,
testCase.inputArgs.network,
testCase.inputArgs.peerKey,
testCase.inputArgs.peerGroupIDs,
testCase.inputArgs.description,
testCase.inputArgs.netID,
testCase.inputArgs.masquerade,
testCase.inputArgs.metric,
testCase.inputArgs.groups,
testCase.inputArgs.enabled,
userID,
)
outRoute, err := am.CreateRoute(account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err)
@@ -379,8 +450,13 @@ func TestSaveRoute(t *testing.T) {
validUsedPeer := peer5ID
invalidPeer := "nonExisting"
validPrefix := netip.MustParsePrefix("192.168.0.0/24")
placeholderPrefix := netip.MustParsePrefix("192.0.2.0/32")
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
validMetric := 1000
trueKeepRoute := true
falseKeepRoute := false
ipv4networkType := route.IPv4Network
domainNetworkType := route.DomainNetwork
invalidMetric := 99999
validNetID := route.NetID("12345678901234567890qw")
invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
@@ -395,6 +471,9 @@ func TestSaveRoute(t *testing.T) {
newPeerGroups []string
newMetric *int
newPrefix *netip.Prefix
newDomains domain.List
newNetworkType *route.NetworkType
newKeepRoute *bool
newGroups []string
skipCopying bool
shouldCreate bool
@@ -402,7 +481,7 @@ func TestSaveRoute(t *testing.T) {
expectedRoute *route.Route
}{
{
name: "Happy Path",
name: "Happy Path Network",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
@@ -434,6 +513,45 @@ func TestSaveRoute(t *testing.T) {
Groups: []string{routeGroup2},
},
},
{
name: "Happy Path Domains",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.Prefix{},
Domains: domain.List{"example.com"},
KeepRoute: false,
NetID: validNetID,
NetworkType: route.DomainNetwork,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeer: &validPeer,
newMetric: &validMetric,
newPrefix: &netip.Prefix{},
newDomains: domain.List{"example.com", "example2.com"},
newKeepRoute: &trueKeepRoute,
newGroups: []string{routeGroup1},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: placeholderPrefix,
Domains: domain.List{"example.com", "example2.com"},
KeepRoute: true,
NetID: validNetID,
NetworkType: route.DomainNetwork,
Peer: validPeer,
Description: "super",
Masquerade: false,
Metric: validMetric,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Happy Path Peer Groups",
existingRoute: &route.Route{
@@ -466,6 +584,23 @@ func TestSaveRoute(t *testing.T) {
Groups: []string{routeGroup2},
},
},
{
name: "Both network and domains provided should fail",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: validNetID,
NetworkType: route.IPv4Network,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPrefix: &validPrefix,
newDomains: domain.List{"example.com"},
errFunc: require.Error,
},
{
name: "Both peer and peers_roup Provided Should Fail",
existingRoute: &route.Route{
@@ -623,7 +758,7 @@ func TestSaveRoute(t *testing.T) {
name: "Allow to modify existing route with new peer",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
Network: existingNetwork,
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: peer1ID,
@@ -638,7 +773,7 @@ func TestSaveRoute(t *testing.T) {
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
Network: existingNetwork,
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: validPeer,
@@ -654,7 +789,7 @@ func TestSaveRoute(t *testing.T) {
name: "Do not allow to modify existing route with a peer from another route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
Network: existingNetwork,
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: peer1ID,
@@ -672,7 +807,7 @@ func TestSaveRoute(t *testing.T) {
name: "Do not allow to modify existing route with a peers group from another route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
Network: existingNetwork,
NetID: validNetID,
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroup3},
@@ -686,6 +821,80 @@ func TestSaveRoute(t *testing.T) {
newPeerGroups: []string{routeGroup4},
errFunc: require.Error,
},
{
name: "Allow switching from network route to domains route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: validPrefix,
Domains: nil,
KeepRoute: false,
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPrefix: &netip.Prefix{},
newDomains: domain.List{"example.com"},
newNetworkType: &domainNetworkType,
newKeepRoute: &trueKeepRoute,
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: placeholderPrefix,
NetworkType: route.DomainNetwork,
Domains: domain.List{"example.com"},
KeepRoute: true,
NetID: validNetID,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Allow switching from domains route to network route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: placeholderPrefix,
Domains: domain.List{"example.com"},
KeepRoute: true,
NetID: validNetID,
NetworkType: route.DomainNetwork,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPrefix: &validPrefix,
newDomains: nil,
newKeepRoute: &falseKeepRoute,
newNetworkType: &ipv4networkType,
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: validPrefix,
NetworkType: route.IPv4Network,
KeepRoute: false,
Domains: nil,
NetID: validNetID,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
@@ -702,7 +911,7 @@ func TestSaveRoute(t *testing.T) {
if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{
ID: "initRoute",
Network: netip.MustParsePrefix(existingNetwork),
Network: existingNetwork,
NetID: existingRouteID,
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroup4},
@@ -739,6 +948,16 @@ func TestSaveRoute(t *testing.T) {
routeToSave.Network = *testCase.newPrefix
}
routeToSave.Domains = testCase.newDomains
if testCase.newNetworkType != nil {
routeToSave.NetworkType = *testCase.newNetworkType
}
if testCase.newKeepRoute != nil {
routeToSave.KeepRoute = *testCase.newKeepRoute
}
if testCase.newGroups != nil {
routeToSave.Groups = testCase.newGroups
}
@@ -771,6 +990,8 @@ func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
Domains: domain.List{"domain1", "domain2"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1Key,
Description: "super",
@@ -839,9 +1060,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(
account.Id, baseRoute.Network.String(), baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description,
baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID)
newRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
@@ -932,9 +1151,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, []string{},
baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false,
userID)
createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(peer1ID)