mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Release 0.28.0 (#2092)
* compile client under freebsd (#1620) Compile netbird client under freebsd and now support netstack and userspace modes. Refactoring linux specific code to share same code with FreeBSD, move to *_unix.go files. Not implemented yet: Kernel mode not supported DNS probably does not work yet Routing also probably does not work yet SSH support did not tested yet Lack of test environment for freebsd (dedicated VM for github runners under FreeBSD required) Lack of tests for freebsd specific code info reporting need to review and also implement, for example OS reported as GENERIC instead of FreeBSD (lack of FreeBSD icon in management interface) Lack of proper client setup under FreeBSD Lack of FreeBSD port/package * Add DNS routes (#1943) Given domains are resolved periodically and resolved IPs are replaced with the new ones. Unless the flag keep_route is set to true, then only new ones are added. This option is helpful if there are long-running connections that might still point to old IP addresses from changed DNS records. * Add process posture check (#1693) Introduces a process posture check to validate the existence and active status of specific binaries on peer systems. The check ensures that files are present at specified paths, and that corresponding processes are running. This check supports Linux, Windows, and macOS systems. Co-authored-by: Evgenii <mail@skillcoder.com> Co-authored-by: Pascal Fischer <pascal@netbird.io> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -12,12 +12,13 @@ import (
|
||||
|
||||
type Client interface {
|
||||
io.Closer
|
||||
Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap() (*proto.NetworkMap, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
IsHealthy() bool
|
||||
SyncMeta(sysInfo *system.Info) error
|
||||
}
|
||||
|
||||
@@ -257,7 +257,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
ch := make(chan *mgmtProto.SyncResponse, 1)
|
||||
|
||||
go func() {
|
||||
err = client.Sync(context.Background(), func(msg *mgmtProto.SyncResponse) error {
|
||||
err = client.Sync(context.Background(), info, func(msg *mgmtProto.SyncResponse) error {
|
||||
ch <- msg
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -29,6 +29,11 @@ import (
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
|
||||
const (
|
||||
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
||||
errMsgNoMgmtConnection = "no connection to management"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorders
|
||||
type ConnStateNotifier interface {
|
||||
MarkManagementDisconnected(error)
|
||||
@@ -113,13 +118,11 @@ func (c *GrpcClient) ready() bool {
|
||||
|
||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||
// Blocking request. The result will be sent via msgHandler callback function
|
||||
func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
backOff := defaultBackoff(ctx)
|
||||
|
||||
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
operation := func() error {
|
||||
log.Debugf("management connection state %v", c.conn.GetState())
|
||||
|
||||
connState := c.conn.GetState()
|
||||
|
||||
if connState == connectivity.Shutdown {
|
||||
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||
@@ -129,55 +132,60 @@ func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncRe
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancelStream := context.WithCancel(ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
c.notifyConnected()
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
s, _ := gstatus.FromError(err)
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
backOff.Reset() // reset backoff counter after successful connection
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
err := backoff.Retry(operation, defaultBackoff(ctx))
|
||||
if err != nil {
|
||||
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
||||
msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
ctx, cancelStream := context.WithCancel(ctx)
|
||||
defer cancelStream()
|
||||
|
||||
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
c.notifyConnected()
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
s, _ := gstatus.FromError(err)
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNetworkMap return with the network map
|
||||
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
@@ -186,7 +194,7 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||
|
||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
return nil, err
|
||||
@@ -219,8 +227,8 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||
return decryptedResp.GetNetworkMap(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{}
|
||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
|
||||
|
||||
myPrivateKey := c.key
|
||||
myPublicKey := myPrivateKey.PublicKey()
|
||||
@@ -269,7 +277,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
|
||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
return nil, fmt.Errorf(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||
@@ -316,7 +324,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
return nil, fmt.Errorf(errMsgNoMgmtConnection)
|
||||
}
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
if err != nil {
|
||||
@@ -431,6 +439,35 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
||||
return flowInfoResp, nil
|
||||
}
|
||||
|
||||
// SyncMeta sends updated system metadata to the Management Service.
|
||||
// It should be used if there is changes on peer posture check after initial sync.
|
||||
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if !c.ready() {
|
||||
return fmt.Errorf(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: syncMetaReq,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyDisconnected(err error) {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
@@ -464,6 +501,15 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
})
|
||||
}
|
||||
|
||||
files := make([]*proto.File, 0, len(info.Files))
|
||||
for _, file := range info.Files {
|
||||
files = append(files, &proto.File{
|
||||
Path: file.Path,
|
||||
Exist: file.Exist,
|
||||
ProcessIsRunning: file.ProcessIsRunning,
|
||||
})
|
||||
}
|
||||
|
||||
return &proto.PeerSystemMeta{
|
||||
Hostname: info.Hostname,
|
||||
GoOS: info.GoOS,
|
||||
@@ -483,5 +529,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
Cloud: info.Environment.Cloud,
|
||||
Platform: info.Environment.Platform,
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,13 @@ import (
|
||||
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
SyncFunc func(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
SyncMetaFunc func(sysInfo *system.Info) error
|
||||
}
|
||||
|
||||
func (m *MockClient) IsHealthy() bool {
|
||||
@@ -30,11 +31,11 @@ func (m *MockClient) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
if m.SyncFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.SyncFunc(ctx, msgHandler)
|
||||
return m.SyncFunc(ctx, sysInfo, msgHandler)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
@@ -73,6 +74,13 @@ func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
|
||||
func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if m.SyncMetaFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.SyncMetaFunc(sysInfo)
|
||||
}
|
||||
|
||||
34
management/domain/domain.go
Normal file
34
management/domain/domain.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
type Domain string
|
||||
|
||||
// String converts the Domain to a non-punycode string.
|
||||
func (d Domain) String() (string, error) {
|
||||
unicode, err := idna.ToUnicode(string(d))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return unicode, nil
|
||||
}
|
||||
|
||||
// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails.
|
||||
func (d Domain) SafeString() string {
|
||||
str, err := d.String()
|
||||
if err != nil {
|
||||
str = string(d)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// FromString creates a Domain from a string, converting it to punycode.
|
||||
func FromString(s string) (Domain, error) {
|
||||
ascii, err := idna.ToASCII(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return Domain(ascii), nil
|
||||
}
|
||||
83
management/domain/list.go
Normal file
83
management/domain/list.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package domain
|
||||
|
||||
import "strings"
|
||||
|
||||
type List []Domain
|
||||
|
||||
// ToStringList converts a List to a slice of string.
|
||||
func (d List) ToStringList() ([]string, error) {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
s, err := domain.String()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, s)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings.
|
||||
func (d List) ToPunycodeList() []string {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
list = append(list, string(domain))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// ToSafeStringList converts the List to a slice of non-punycode strings.
|
||||
// If a domain cannot be converted, the original string is used.
|
||||
func (d List) ToSafeStringList() []string {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
list = append(list, domain.SafeString())
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// String converts List to a comma-separated string.
|
||||
func (d List) String() (string, error) {
|
||||
list, err := d.ToStringList()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Join(list, ", "), nil
|
||||
}
|
||||
|
||||
// SafeString converts List to a comma-separated non-punycode string.
|
||||
// If a domain cannot be converted, the original string is used.
|
||||
func (d List) SafeString() string {
|
||||
str, err := d.String()
|
||||
if err != nil {
|
||||
return strings.Join(d.ToPunycodeList(), ", ")
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains.
|
||||
func (d List) PunycodeString() string {
|
||||
return strings.Join(d.ToPunycodeList(), ", ")
|
||||
}
|
||||
|
||||
// FromStringList creates a DomainList from a slice of string.
|
||||
func FromStringList(s []string) (List, error) {
|
||||
var dl List
|
||||
for _, domain := range s {
|
||||
d, err := FromString(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl = append(dl, d)
|
||||
}
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings.
|
||||
func FromPunycodeList(s []string) List {
|
||||
var dl List
|
||||
for _, domain := range s {
|
||||
dl = append(dl, Domain(domain))
|
||||
}
|
||||
return dl
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -38,6 +38,12 @@ service ManagementService {
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
|
||||
}
|
||||
|
||||
message EncryptedMessage {
|
||||
@@ -50,7 +56,10 @@ message EncryptedMessage {
|
||||
int32 version = 3;
|
||||
}
|
||||
|
||||
message SyncRequest {}
|
||||
message SyncRequest {
|
||||
// Meta data of the peer
|
||||
PeerSystemMeta meta = 1;
|
||||
}
|
||||
|
||||
// SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs)
|
||||
message SyncResponse {
|
||||
@@ -69,6 +78,14 @@ message SyncResponse {
|
||||
bool remotePeersIsEmpty = 4;
|
||||
|
||||
NetworkMap NetworkMap = 5;
|
||||
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 6;
|
||||
}
|
||||
|
||||
message SyncMetaRequest {
|
||||
// Meta data of the peer
|
||||
PeerSystemMeta meta = 1;
|
||||
}
|
||||
|
||||
message LoginRequest {
|
||||
@@ -82,6 +99,7 @@ message LoginRequest {
|
||||
PeerKeys peerKeys = 4;
|
||||
|
||||
}
|
||||
|
||||
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
|
||||
// This message is sent on Login or register requests, or when a key rotation has to happen.
|
||||
message PeerKeys {
|
||||
@@ -100,6 +118,16 @@ message Environment {
|
||||
string platform = 2;
|
||||
}
|
||||
|
||||
// File represents a file on the system.
|
||||
message File {
|
||||
// path is the path to the file.
|
||||
string path = 1;
|
||||
// exist indicate whether the file exists.
|
||||
bool exist = 2;
|
||||
// processIsRunning indicates whether the file is a running process or not.
|
||||
bool processIsRunning = 3;
|
||||
}
|
||||
|
||||
// PeerSystemMeta is machine meta data like OS and version.
|
||||
message PeerSystemMeta {
|
||||
string hostname = 1;
|
||||
@@ -117,6 +145,7 @@ message PeerSystemMeta {
|
||||
string sysProductName = 13;
|
||||
string sysManufacturer = 14;
|
||||
Environment environment = 15;
|
||||
repeated File files = 16;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -124,6 +153,8 @@ message LoginResponse {
|
||||
WiretrusteeConfig wiretrusteeConfig = 1;
|
||||
// Peer local config
|
||||
PeerConfig peerConfig = 2;
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 3;
|
||||
}
|
||||
|
||||
message ServerKeyResponse {
|
||||
@@ -303,6 +334,8 @@ message Route {
|
||||
int64 Metric = 5;
|
||||
bool Masquerade = 6;
|
||||
string NetID = 7;
|
||||
repeated string Domains = 8;
|
||||
bool keepRoute = 9;
|
||||
}
|
||||
|
||||
// DNSConfig represents a dns.Update
|
||||
@@ -371,3 +404,7 @@ message NetworkAddress {
|
||||
string netIP = 1;
|
||||
string mac = 2;
|
||||
}
|
||||
|
||||
message Checks {
|
||||
repeated string Files= 1;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,11 @@ type ManagementServiceClient interface {
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
|
||||
}
|
||||
|
||||
type managementServiceClient struct {
|
||||
@@ -130,6 +135,15 @@ func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context,
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ManagementServiceServer is the server API for ManagementService service.
|
||||
// All implementations must embed UnimplementedManagementServiceServer
|
||||
// for forward compatibility
|
||||
@@ -159,6 +173,11 @@ type ManagementServiceServer interface {
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
|
||||
mustEmbedUnimplementedManagementServiceServer()
|
||||
}
|
||||
|
||||
@@ -184,6 +203,9 @@ func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.C
|
||||
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
|
||||
|
||||
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -308,6 +330,24 @@ func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx co
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/SyncMeta",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -335,6 +375,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetPKCEAuthorizationFlow",
|
||||
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SyncMeta",
|
||||
Handler: _ManagementService_SyncMeta_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -102,7 +103,7 @@ type AccountManager interface {
|
||||
DeletePolicy(accountID, policyID, userID string) error
|
||||
ListPolicies(accountID, userID string) ([]*Policy, error)
|
||||
GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
SaveRoute(accountID, userID string, route *route.Route) error
|
||||
DeleteRoute(accountID string, routeID route.ID, userID string) error
|
||||
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
||||
@@ -117,6 +118,7 @@ type AccountManager interface {
|
||||
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
|
||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error)
|
||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
@@ -131,8 +133,9 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
||||
GroupValidation(accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||
SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
}
|
||||
|
||||
@@ -275,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
|
||||
peerRoutesMembership := make(lookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.getPeerGroups(peerID)
|
||||
@@ -293,7 +296,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(route.GetHAUniqueID(r))]
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
@@ -376,11 +379,13 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// GetRoutesByPrefix return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||
var routes []*route.Route
|
||||
for _, r := range a.Routes {
|
||||
if r.Network.String() == prefix.String() {
|
||||
if r.IsDynamic() && r.Domains.PunycodeString() == domains.PunycodeString() {
|
||||
routes = append(routes, r)
|
||||
} else if r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
@@ -1850,7 +1855,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
@@ -1867,7 +1872,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1906,6 +1911,27 @@ func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
|
||||
@@ -1435,7 +1435,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesByPrefix(prefix)
|
||||
routes := account.GetRoutesByPrefixOrDomains(prefix, nil)
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -134,7 +135,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
return err
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
}
|
||||
@@ -157,12 +162,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
// keep a connection to the peer and send updates when available
|
||||
return s.handleUpdates(peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
case update, open := <-updates:
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
}
|
||||
@@ -174,21 +182,10 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
}
|
||||
log.Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.Debugf("sent an update to peer %s", peerKey.String())
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
@@ -199,6 +196,26 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
}
|
||||
}
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
@@ -250,14 +267,18 @@ func mapError(err error) error {
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
|
||||
func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
|
||||
osVersion := loginReq.GetMeta().GetOSVersion()
|
||||
if osVersion == "" {
|
||||
osVersion = loginReq.GetMeta().GetCore()
|
||||
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
if meta == nil {
|
||||
return nbpeer.PeerSystemMeta{}
|
||||
}
|
||||
|
||||
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses()))
|
||||
for _, addr := range loginReq.GetMeta().GetNetworkAddresses() {
|
||||
osVersion := meta.GetOSVersion()
|
||||
if osVersion == "" {
|
||||
osVersion = meta.GetCore()
|
||||
}
|
||||
|
||||
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses()))
|
||||
for _, addr := range meta.GetNetworkAddresses() {
|
||||
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
@@ -269,24 +290,34 @@ func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
|
||||
})
|
||||
}
|
||||
|
||||
files := make([]nbpeer.File, 0, len(meta.GetFiles()))
|
||||
for _, file := range meta.GetFiles() {
|
||||
files = append(files, nbpeer.File{
|
||||
Path: file.GetPath(),
|
||||
Exist: file.GetExist(),
|
||||
ProcessIsRunning: file.GetProcessIsRunning(),
|
||||
})
|
||||
}
|
||||
|
||||
return nbpeer.PeerSystemMeta{
|
||||
Hostname: loginReq.GetMeta().GetHostname(),
|
||||
GoOS: loginReq.GetMeta().GetGoOS(),
|
||||
Kernel: loginReq.GetMeta().GetKernel(),
|
||||
Platform: loginReq.GetMeta().GetPlatform(),
|
||||
OS: loginReq.GetMeta().GetOS(),
|
||||
Hostname: meta.GetHostname(),
|
||||
GoOS: meta.GetGoOS(),
|
||||
Kernel: meta.GetKernel(),
|
||||
Platform: meta.GetPlatform(),
|
||||
OS: meta.GetOS(),
|
||||
OSVersion: osVersion,
|
||||
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
|
||||
UIVersion: loginReq.GetMeta().GetUiVersion(),
|
||||
KernelVersion: loginReq.GetMeta().GetKernelVersion(),
|
||||
WtVersion: meta.GetWiretrusteeVersion(),
|
||||
UIVersion: meta.GetUiVersion(),
|
||||
KernelVersion: meta.GetKernelVersion(),
|
||||
NetworkAddresses: networkAddresses,
|
||||
SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(),
|
||||
SystemProductName: loginReq.GetMeta().GetSysProductName(),
|
||||
SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(),
|
||||
SystemSerialNumber: meta.GetSysSerialNumber(),
|
||||
SystemProductName: meta.GetSysProductName(),
|
||||
SystemManufacturer: meta.GetSysManufacturer(),
|
||||
Environment: nbpeer.Environment{
|
||||
Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(),
|
||||
Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(),
|
||||
Cloud: meta.GetEnvironment().GetCloud(),
|
||||
Platform: meta.GetEnvironment().GetPlatform(),
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,24 +366,11 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
userID := ""
|
||||
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
|
||||
// or it uses a setup key to register.
|
||||
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
"Trying again as it may be due to the IdP cache issue", peerKey, err)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userID, err := s.processJwtToken(loginReq, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sshKey []byte
|
||||
if loginReq.GetPeerKeys() != nil {
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
@@ -361,12 +379,11 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: extractPeerMeta(loginReq),
|
||||
Meta: extractPeerMeta(loginReq.GetMeta()),
|
||||
UserID: userID,
|
||||
SetupKey: loginReq.GetSetupKey(),
|
||||
ConnectionIP: realIP,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(err)
|
||||
@@ -381,6 +398,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
loginResp := &proto.LoginResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||
Checks: toProtocolChecks(s.accountManager, peerKey.String()),
|
||||
}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
@@ -394,6 +412,31 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
|
||||
// the token is valid.
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case UDP:
|
||||
@@ -477,7 +520,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
@@ -508,6 +551,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
},
|
||||
Checks: toProtocolChecks(accountManager, peer.Key),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,7 +570,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
|
||||
} else {
|
||||
turnCredentials = nil
|
||||
}
|
||||
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
||||
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
@@ -643,3 +687,67 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
|
||||
// peer's under the same account of any updates.
|
||||
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
syncMetaReq := &proto.SyncMetaRequest{}
|
||||
peerKey, err := s.parseRequest(req, syncMetaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if syncMetaReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(err)
|
||||
}
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side.
|
||||
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
|
||||
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
protoChecks := make([]*proto.Checks, 0)
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
}
|
||||
|
||||
// toProtocolCheck converts a posture.Checks to a proto.Checks.
|
||||
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks {
|
||||
protoCheck := &proto.Checks{}
|
||||
|
||||
if check := postureCheck.Checks.ProcessCheck; check != nil {
|
||||
for _, process := range check.Processes {
|
||||
if process.LinuxPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.LinuxPath)
|
||||
}
|
||||
if process.MacPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.MacPath)
|
||||
}
|
||||
if process.WindowsPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.WindowsPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -817,6 +817,8 @@ components:
|
||||
$ref: '#/components/schemas/GeoLocationCheck'
|
||||
peer_network_range_check:
|
||||
$ref: '#/components/schemas/PeerNetworkRangeCheck'
|
||||
process_check:
|
||||
$ref: '#/components/schemas/ProcessCheck'
|
||||
NBVersionCheck:
|
||||
description: Posture check for the version of NetBird
|
||||
type: object
|
||||
@@ -905,6 +907,32 @@ components:
|
||||
required:
|
||||
- ranges
|
||||
- action
|
||||
ProcessCheck:
|
||||
description: Posture Check for binaries exist and are running in the peer’s system
|
||||
type: object
|
||||
properties:
|
||||
processes:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Process'
|
||||
required:
|
||||
- processes
|
||||
Process:
|
||||
description: Describes the operational activity within a peer's system.
|
||||
type: object
|
||||
properties:
|
||||
linux_path:
|
||||
description: Path to the process executable file in a Linux operating system
|
||||
type: string
|
||||
example: "/usr/local/bin/netbird"
|
||||
mac_path:
|
||||
description: Path to the process executable file in a Mac operating system
|
||||
type: string
|
||||
example: "/Applications/NetBird.app/Contents/MacOS/netbird"
|
||||
windows_path:
|
||||
description: Path to the process executable file in a Windows operating system
|
||||
type: string
|
||||
example: "C:\ProgramData\NetBird\netbird.exe"
|
||||
Location:
|
||||
description: Describe geographical location information
|
||||
type: object
|
||||
@@ -995,9 +1023,17 @@ components:
|
||||
type: string
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
network:
|
||||
description: Network range in CIDR format
|
||||
description: Network range in CIDR format, Conflicts with domains
|
||||
type: string
|
||||
example: 10.64.0.0/24
|
||||
domains:
|
||||
description: Domain list to be dynamically resolved. Conflicts with network
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minLength: 1
|
||||
maxLength: 255
|
||||
example: "example.com"
|
||||
metric:
|
||||
description: Route metric number. Lowest number has higher priority
|
||||
type: integer
|
||||
@@ -1014,6 +1050,10 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "chacdk86lnnboviihd70"
|
||||
keep_route:
|
||||
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- id
|
||||
- description
|
||||
@@ -1022,10 +1062,13 @@ components:
|
||||
# Only one property has to be set
|
||||
#- peer
|
||||
#- peer_groups
|
||||
- network
|
||||
# Only one property has to be set
|
||||
#- network
|
||||
#- domains
|
||||
- metric
|
||||
- masquerade
|
||||
- groups
|
||||
- keep_route
|
||||
Route:
|
||||
allOf:
|
||||
- type: object
|
||||
@@ -1035,7 +1078,7 @@ components:
|
||||
type: string
|
||||
example: chacdk86lnnboviihd7g
|
||||
network_type:
|
||||
description: Network type indicating if it is IPv4 or IPv6
|
||||
description: Network type indicating if it is a domain route or a IPv4/IPv6 route
|
||||
type: string
|
||||
example: IPv4
|
||||
required:
|
||||
|
||||
@@ -225,6 +225,9 @@ type Checks struct {
|
||||
|
||||
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
|
||||
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
|
||||
|
||||
// ProcessCheck Posture Check for binaries exist and are running in the peer’s system
|
||||
ProcessCheck *ProcessCheck `json:"process_check,omitempty"`
|
||||
}
|
||||
|
||||
// City Describe city geographical location information
|
||||
@@ -949,11 +952,31 @@ type PostureCheckUpdate struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Process Describes the operational activity within a peer's system.
|
||||
type Process struct {
|
||||
// LinuxPath Path to the process executable file in a Linux operating system
|
||||
LinuxPath *string `json:"linux_path,omitempty"`
|
||||
|
||||
// MacPath Path to the process executable file in a Mac operating system
|
||||
MacPath *string `json:"mac_path,omitempty"`
|
||||
|
||||
// WindowsPath Path to the process executable file in a Windows operating system
|
||||
WindowsPath *string `json:"windows_path,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessCheck Posture Check for binaries exist and are running in the peer’s system
|
||||
type ProcessCheck struct {
|
||||
Processes []Process `json:"processes"`
|
||||
}
|
||||
|
||||
// Route defines model for Route.
|
||||
type Route struct {
|
||||
// Description Route description
|
||||
Description string `json:"description"`
|
||||
|
||||
// Domains Domain list to be dynamically resolved. Conflicts with network
|
||||
Domains *[]string `json:"domains,omitempty"`
|
||||
|
||||
// Enabled Route status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -963,19 +986,22 @@ type Route struct {
|
||||
// Id Route Id
|
||||
Id string `json:"id"`
|
||||
|
||||
// KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore
|
||||
KeepRoute bool `json:"keep_route"`
|
||||
|
||||
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
|
||||
Masquerade bool `json:"masquerade"`
|
||||
|
||||
// Metric Route metric number. Lowest number has higher priority
|
||||
Metric int `json:"metric"`
|
||||
|
||||
// Network Network range in CIDR format
|
||||
Network string `json:"network"`
|
||||
// Network Network range in CIDR format, Conflicts with domains
|
||||
Network *string `json:"network,omitempty"`
|
||||
|
||||
// NetworkId Route network identifier, to group HA routes
|
||||
NetworkId string `json:"network_id"`
|
||||
|
||||
// NetworkType Network type indicating if it is IPv4 or IPv6
|
||||
// NetworkType Network type indicating if it is a domain route or a IPv4/IPv6 route
|
||||
NetworkType string `json:"network_type"`
|
||||
|
||||
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
|
||||
@@ -990,20 +1016,26 @@ type RouteRequest struct {
|
||||
// Description Route description
|
||||
Description string `json:"description"`
|
||||
|
||||
// Domains Domain list to be dynamically resolved. Conflicts with network
|
||||
Domains *[]string `json:"domains,omitempty"`
|
||||
|
||||
// Enabled Route status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Groups Group IDs containing routing peers
|
||||
Groups []string `json:"groups"`
|
||||
|
||||
// KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore
|
||||
KeepRoute bool `json:"keep_route"`
|
||||
|
||||
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
|
||||
Masquerade bool `json:"masquerade"`
|
||||
|
||||
// Metric Route metric number. Lowest number has higher priority
|
||||
Metric int `json:"metric"`
|
||||
|
||||
// Network Network range in CIDR format
|
||||
Network string `json:"network"`
|
||||
// Network Network range in CIDR format, Conflicts with domains
|
||||
Network *string `json:"network,omitempty"`
|
||||
|
||||
// NetworkId Route network identifier, to group HA routes
|
||||
NetworkId string `json:"network_id"`
|
||||
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
@@ -13,6 +14,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
var (
|
||||
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
|
||||
)
|
||||
|
||||
// GeolocationsHandler is a handler that returns locations.
|
||||
type GeolocationsHandler struct {
|
||||
accountManager server.AccountManager
|
||||
@@ -73,8 +78,8 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -27,12 +27,12 @@ const (
|
||||
notFoundUserID = "notFoundUserID"
|
||||
existingTokenID = "existingTokenID"
|
||||
notFoundTokenID = "notFoundTokenID"
|
||||
domain = "hotmail.com"
|
||||
testDomain = "hotmail.com"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: domain,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
@@ -117,7 +117,7 @@ func initPATTestData() *PATHandler {
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: domain,
|
||||
Domain: testDomain,
|
||||
AccountId: testNSGroupAccountID,
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -3,8 +3,6 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
@@ -17,10 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
var (
|
||||
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
|
||||
)
|
||||
|
||||
// PostureChecksHandler is a handler that returns posture checks of the account.
|
||||
type PostureChecksHandler struct {
|
||||
accountManager server.AccountManager
|
||||
@@ -163,23 +157,20 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
user *server.User,
|
||||
postureChecksID string,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
req api.PostureCheckUpdate
|
||||
)
|
||||
|
||||
var req api.PostureCheckUpdate
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := validatePostureChecksUpdate(req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
if p.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -197,69 +188,3 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
|
||||
if req.Name == "" {
|
||||
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
|
||||
}
|
||||
|
||||
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
|
||||
req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
|
||||
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
|
||||
}
|
||||
|
||||
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
|
||||
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
|
||||
}
|
||||
|
||||
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
|
||||
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
|
||||
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
|
||||
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
|
||||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
|
||||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
|
||||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
|
||||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
|
||||
if emptyOS || emptyMinVersion {
|
||||
return status.Errorf(status.InvalidArgument,
|
||||
"minimum version for at least one OS in the OS version check shouldn't be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
if geoLocationCheck.Action == "" {
|
||||
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
|
||||
}
|
||||
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
|
||||
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
|
||||
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
|
||||
}
|
||||
if len(geoLocationCheck.Locations) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
|
||||
}
|
||||
for _, loc := range geoLocationCheck.Locations {
|
||||
if loc.CountryCode == "" {
|
||||
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
|
||||
}
|
||||
if !countryCodeRegex.MatchString(loc.CountryCode) {
|
||||
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
|
||||
if peerNetworkRangeCheck.Action == "" {
|
||||
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
|
||||
}
|
||||
|
||||
allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
|
||||
if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
|
||||
return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
|
||||
}
|
||||
if len(peerNetworkRangeCheck.Ranges) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -43,6 +43,11 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
|
||||
@@ -433,6 +438,45 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
handler.geolocationManager = nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create Posture Checks Process Check",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/posture-checks",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`{
|
||||
"name": "default",
|
||||
"description": "default",
|
||||
"checks": {
|
||||
"process_check": {
|
||||
"processes": [
|
||||
{
|
||||
"linux_path": "/usr/local/bin/netbird",
|
||||
"mac_path": "/Applications/NetBird.app/Contents/MacOS/netbird",
|
||||
"windows_path": "C:\\ProgramData\\NetBird\\netbird.exe"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedPostureCheck: &api.PostureCheck{
|
||||
Id: "postureCheck",
|
||||
Name: "default",
|
||||
Description: str("default"),
|
||||
Checks: api.Checks{
|
||||
ProcessCheck: &api.ProcessCheck{
|
||||
Processes: []api.Process{
|
||||
{
|
||||
LinuxPath: str("/usr/local/bin/netbird"),
|
||||
MacPath: str("/Applications/NetBird.app/Contents/MacOS/netbird"),
|
||||
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create Posture Checks Invalid Check",
|
||||
requestType: http.MethodPost,
|
||||
@@ -446,7 +490,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -461,7 +505,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -475,7 +519,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
"nb_version_check": {}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -489,7 +533,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
"geo_location_check": {}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -663,11 +707,8 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
setupHandlerFunc: func(handler *PostureChecksHandler) {
|
||||
handler.geolocationManager = nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update Posture Checks Invalid Check",
|
||||
@@ -682,7 +723,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -697,7 +738,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -711,7 +752,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
"nb_version_check": {}
|
||||
}
|
||||
}`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@@ -841,100 +882,3 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
|
||||
// empty name
|
||||
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
|
||||
assert.Error(t, err)
|
||||
|
||||
// empty checks
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
|
||||
assert.Error(t, err)
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
|
||||
assert.Error(t, err)
|
||||
|
||||
// not valid NbVersionCheck
|
||||
nbVersionCheck := api.NBVersionCheck{}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
|
||||
assert.Error(t, err)
|
||||
|
||||
// valid NbVersionCheck
|
||||
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// not valid OsVersionCheck
|
||||
osVersionCheck := api.OSVersionCheck{}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
|
||||
assert.Error(t, err)
|
||||
|
||||
// not valid OsVersionCheck
|
||||
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
|
||||
assert.Error(t, err)
|
||||
|
||||
// not valid OsVersionCheck
|
||||
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
|
||||
assert.Error(t, err)
|
||||
|
||||
// valid OsVersionCheck
|
||||
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// valid OsVersionCheck
|
||||
osVersionCheck = api.OSVersionCheck{
|
||||
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
|
||||
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
|
||||
}
|
||||
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// valid peer network range check
|
||||
peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
|
||||
Action: api.PeerNetworkRangeCheckActionAllow,
|
||||
Ranges: []string{
|
||||
"192.168.1.0/24", "10.0.0.0/8",
|
||||
},
|
||||
}
|
||||
err = validatePostureChecksUpdate(
|
||||
api.PostureCheckUpdate{
|
||||
Name: "Default",
|
||||
Checks: &api.Checks{
|
||||
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// invalid peer network range check
|
||||
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
|
||||
Action: api.PeerNetworkRangeCheckActionDeny,
|
||||
Ranges: []string{},
|
||||
}
|
||||
err = validatePostureChecksUpdate(
|
||||
api.PostureCheckUpdate{
|
||||
Name: "Default",
|
||||
Checks: &api.Checks{
|
||||
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
|
||||
// invalid peer network range check
|
||||
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
|
||||
Action: "unknownAction",
|
||||
Ranges: []string{},
|
||||
}
|
||||
err = validatePostureChecksUpdate(
|
||||
api.PostureCheckUpdate{
|
||||
Name: "Default",
|
||||
Checks: &api.Checks{
|
||||
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -2,11 +2,16 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
@@ -15,6 +20,9 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const maxDomains = 32
|
||||
const failedToConvertRoute = "failed to convert route to response: %v"
|
||||
|
||||
// RoutesHandler is the routes handler of the account
|
||||
type RoutesHandler struct {
|
||||
accountManager server.AccountManager
|
||||
@@ -48,7 +56,12 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
apiRoutes := make([]*api.Route, 0)
|
||||
for _, r := range routes {
|
||||
apiRoutes = append(apiRoutes, toRouteResponse(r))
|
||||
route, err := toRouteResponse(r)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
apiRoutes = append(apiRoutes, route)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiRoutes)
|
||||
@@ -70,16 +83,28 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, newPrefix, err := route.ParseNetwork(req.Network)
|
||||
if err != nil {
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
|
||||
route.MaxNetIDChar), w)
|
||||
return
|
||||
var domains domain.List
|
||||
var networkType route.NetworkType
|
||||
var newPrefix netip.Prefix
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
domains = d
|
||||
networkType = route.DomainNetwork
|
||||
} else if req.Network != nil {
|
||||
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
peerId := ""
|
||||
@@ -87,36 +112,57 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
peerId = *req.Peer
|
||||
}
|
||||
|
||||
peerGroupIds := []string{}
|
||||
var peerGroupIds []string
|
||||
if req.PeerGroups != nil {
|
||||
peerGroupIds = *req.PeerGroups
|
||||
}
|
||||
|
||||
if (peerId != "" && len(peerGroupIds) > 0) || (peerId == "" && len(peerGroupIds) == 0) {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "only one peer or peer_groups should be provided"), w)
|
||||
return
|
||||
}
|
||||
|
||||
// do not allow non Linux peers
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerId); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(
|
||||
account.Id, newPrefix.String(), peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
|
||||
)
|
||||
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRouteResponse(newRoute)
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, routes)
|
||||
}
|
||||
|
||||
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
if req.Network != nil && req.Domains != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if req.Network == nil && req.Domains == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters",
|
||||
route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute handles update to a route identified by a given ID
|
||||
@@ -148,26 +194,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
|
||||
req.Network, routeID), w)
|
||||
return
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "only peer or peers_group should be provided"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "either peer or peers_group should be provided"), w)
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,14 +214,29 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID(routeID),
|
||||
Network: newPrefix,
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
NetworkType: prefixType,
|
||||
Masquerade: req.Masquerade,
|
||||
Metric: req.Metric,
|
||||
Description: req.Description,
|
||||
Enabled: req.Enabled,
|
||||
Groups: req.Groups,
|
||||
KeepRoute: req.KeepRoute,
|
||||
}
|
||||
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
newRoute.Domains = d
|
||||
newRoute.NetworkType = route.DomainNetwork
|
||||
} else if req.Network != nil {
|
||||
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.Peer != nil {
|
||||
@@ -210,9 +253,13 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRouteResponse(newRoute)
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, routes)
|
||||
}
|
||||
|
||||
// DeleteRoute handles route deletion request
|
||||
@@ -260,25 +307,69 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toRouteResponse(foundRoute))
|
||||
routes, err := toRouteResponse(foundRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
}
|
||||
|
||||
func toRouteResponse(serverRoute *route.Route) *api.Route {
|
||||
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
||||
domains, err := serverRoute.Domains.ToStringList()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
network := serverRoute.Network.String()
|
||||
route := &api.Route{
|
||||
Id: string(serverRoute.ID),
|
||||
Description: serverRoute.Description,
|
||||
NetworkId: string(serverRoute.NetID),
|
||||
Enabled: serverRoute.Enabled,
|
||||
Peer: &serverRoute.Peer,
|
||||
Network: serverRoute.Network.String(),
|
||||
Network: &network,
|
||||
Domains: &domains,
|
||||
NetworkType: serverRoute.NetworkType.String(),
|
||||
Masquerade: serverRoute.Masquerade,
|
||||
Metric: serverRoute.Metric,
|
||||
Groups: serverRoute.Groups,
|
||||
KeepRoute: serverRoute.KeepRoute,
|
||||
}
|
||||
|
||||
if len(serverRoute.PeerGroups) > 0 {
|
||||
route.PeerGroups = &serverRoute.PeerGroups
|
||||
}
|
||||
return route
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// validateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
||||
func validateDomains(domains []string) (domain.List, error) {
|
||||
if len(domains) == 0 {
|
||||
return nil, fmt.Errorf("domains list is empty")
|
||||
}
|
||||
if len(domains) > maxDomains {
|
||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
var domainList domain.List
|
||||
|
||||
for _, d := range domains {
|
||||
d := strings.ToLower(d)
|
||||
|
||||
// handles length and idna conversion
|
||||
punycode, err := domain.FromString(d)
|
||||
if err != nil {
|
||||
return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err)
|
||||
}
|
||||
|
||||
if !domainRegex.MatchString(string(punycode)) {
|
||||
return domainList, fmt.Errorf("invalid domain format: %s", d)
|
||||
}
|
||||
|
||||
domainList = append(domainList, punycode)
|
||||
}
|
||||
return domainList, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -26,6 +29,7 @@ import (
|
||||
const (
|
||||
existingRouteID = "existingRouteID"
|
||||
existingRouteID2 = "existingRouteID2" // for peer_groups test
|
||||
existingRouteID3 = "existingRouteID3" // for domains test
|
||||
notFoundRouteID = "notFoundRouteID"
|
||||
existingPeerIP1 = "100.64.0.100"
|
||||
existingPeerIP2 = "100.64.0.101"
|
||||
@@ -35,6 +39,7 @@ const (
|
||||
testAccountID = "test_id"
|
||||
existingGroupID = "testGroup"
|
||||
notFoundGroupID = "nonExistingGroup"
|
||||
existingDomain = "example.com"
|
||||
)
|
||||
|
||||
var emptyString = ""
|
||||
@@ -46,6 +51,8 @@ var baseExistingRoute = &route.Route{
|
||||
Description: "base route",
|
||||
NetID: "awesomeNet",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
Domains: domain.List{},
|
||||
KeepRoute: false,
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -90,28 +97,33 @@ func initRoutesTestData() *RoutesHandler {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.PeerGroups = []string{existingGroupID}
|
||||
return route, nil
|
||||
} else if routeID == existingRouteID3 {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.Domains = domain.List{existingDomain}
|
||||
return route, nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
|
||||
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
if peerID == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
||||
}
|
||||
networkType, p, _ := route.ParseNetwork(network)
|
||||
return &route.Route{
|
||||
ID: existingRouteID,
|
||||
NetID: netID,
|
||||
Peer: peerID,
|
||||
PeerGroups: peerGroups,
|
||||
Network: p,
|
||||
Network: prefix,
|
||||
Domains: domains,
|
||||
NetworkType: networkType,
|
||||
Description: description,
|
||||
Masquerade: masquerade,
|
||||
Enabled: enabled,
|
||||
Groups: groups,
|
||||
KeepRoute: keepRoute,
|
||||
}, nil
|
||||
},
|
||||
SaveRouteFunc: func(_, _ string, r *route.Route) error {
|
||||
@@ -146,6 +158,9 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
baseExistingRouteWithPeerGroups := baseExistingRoute.Copy()
|
||||
baseExistingRouteWithPeerGroups.PeerGroups = []string{existingGroupID}
|
||||
|
||||
baseExistingRouteWithDomains := baseExistingRoute.Copy()
|
||||
baseExistingRouteWithDomains.Domains = domain.List{existingDomain}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
@@ -161,7 +176,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: toRouteResponse(baseExistingRoute),
|
||||
expectedRoute: toApiRoute(t, baseExistingRoute),
|
||||
},
|
||||
{
|
||||
name: "Get Not Existing Route",
|
||||
@@ -175,7 +190,15 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestPath: "/api/routes/" + existingRouteID2,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: toRouteResponse(baseExistingRouteWithPeerGroups),
|
||||
expectedRoute: toApiRoute(t, baseExistingRouteWithPeerGroups),
|
||||
},
|
||||
{
|
||||
name: "Get Existing Route with Domains",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/routes/" + existingRouteID3,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: toApiRoute(t, baseExistingRouteWithDomains),
|
||||
},
|
||||
{
|
||||
name: "Delete Existing Route",
|
||||
@@ -191,18 +214,18 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "POST OK",
|
||||
name: "Network POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: "192.168.0.0/16",
|
||||
Network: toPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
@@ -210,6 +233,28 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Domains POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "domainNet",
|
||||
Network: toPtr("invalid Prefix"),
|
||||
KeepRoute: true,
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "POST Non Linux Peer",
|
||||
requestType: http.MethodPost,
|
||||
@@ -242,6 +287,32 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST Invalid Domains",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST UnprocessableEntity when both network and domains are provided",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST UnprocessableEntity when no network and domains are provided",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","groups":["%s"]}`, existingPeerID))),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST UnprocessableEntity when both peer and peer_groups are provided",
|
||||
requestType: http.MethodPost,
|
||||
@@ -261,7 +332,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "PUT OK",
|
||||
name: "Network PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)),
|
||||
@@ -271,7 +342,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: "192.168.0.0/16",
|
||||
Network: toPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
@@ -279,6 +350,27 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Domains PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: toPtr("invalid Prefix"),
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
KeepRoute: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PUT OK when peer_groups provided",
|
||||
requestType: http.MethodPut,
|
||||
@@ -290,7 +382,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: "192.168.0.0/16",
|
||||
Network: toPtr("192.168.0.0/16"),
|
||||
Peer: &emptyString,
|
||||
PeerGroups: &[]string{existingGroupID},
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
@@ -339,6 +431,33 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "PUT Invalid Domains",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "PUT UnprocessableEntity when both network and domains are provided",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "PUT UnprocessableEntity when no network and domains are provided",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "PUT UnprocessableEntity when both peer and peer_groups are provided",
|
||||
requestType: http.MethodPut,
|
||||
@@ -399,3 +518,85 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expected domain.List
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty list",
|
||||
domains: nil,
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid ASCII domain",
|
||||
domains: []string{"sub.ex-ample.com"},
|
||||
expected: domain.List{"sub.ex-ample.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Unicode domain",
|
||||
domains: []string{"münchen.de"},
|
||||
expected: domain.List{"xn--mnchen-3ya.de"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Unicode, all labels",
|
||||
domains: []string{"中国.中国.中国"},
|
||||
expected: domain.List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "With underscores",
|
||||
domains: []string{"_jabber._tcp.gmail.com"},
|
||||
expected: domain.List{"_jabber._tcp.gmail.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format",
|
||||
domains: []string{"-example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format 2",
|
||||
domains: []string{"example.com-"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple domains valid and invalid",
|
||||
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
|
||||
expected: domain.List{"google.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := validateDomains(tt.domains)
|
||||
assert.Equal(t, tt.wantErr, err != nil)
|
||||
assert.Equal(t, got, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toApiRoute(t *testing.T, r *route.Route) *api.Route {
|
||||
t.Helper()
|
||||
|
||||
apiRoute, err := toRouteResponse(r)
|
||||
// json flattens pointer to nil slices to null
|
||||
if apiRoute.Domains != nil && *apiRoute.Domains == nil {
|
||||
apiRoute.Domains = nil
|
||||
}
|
||||
require.NoError(t, err, "Failed to convert route")
|
||||
return apiRoute
|
||||
}
|
||||
|
||||
func toPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ const (
|
||||
|
||||
var usersTestAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: domain,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
@@ -127,7 +127,7 @@ func initUsersTestData() *UsersHandler {
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: domain,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -134,7 +134,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
// take the first registered peer as a base for the test. Total four.
|
||||
key := *peers[0]
|
||||
|
||||
message, err := encryption.EncryptMessage(*serverKey, key, &mgmtProto.SyncRequest{})
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message, err := encryption.EncryptMessage(*serverKey, key, syncReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -93,7 +93,8 @@ var _ = Describe("Management service", func() {
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||
|
||||
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
||||
@@ -143,7 +144,7 @@ var _ = Describe("Management service", func() {
|
||||
loginPeerWithValidSetupKey(serverPubKey, key1, client)
|
||||
loginPeerWithValidSetupKey(serverPubKey, key2, client)
|
||||
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -176,7 +177,7 @@ var _ = Describe("Management service", func() {
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -329,7 +330,7 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
var clients []mgmtProto.ManagementService_SyncClient
|
||||
for _, peer := range peers {
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -394,7 +395,8 @@ var _ = Describe("Management service", func() {
|
||||
defer GinkgoRecover()
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// open stream
|
||||
|
||||
@@ -2,12 +2,14 @@ package mock_server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
@@ -28,7 +30,7 @@ type MockAccountManager struct {
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
@@ -52,7 +54,7 @@ type MockAccountManager struct {
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
CreateRouteFunc func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
|
||||
@@ -81,6 +83,7 @@ type MockAccountManager struct {
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
@@ -95,12 +98,13 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
|
||||
return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -413,9 +417,9 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
|
||||
}
|
||||
|
||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||
func (am *MockAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
if am.CreateRouteFunc != nil {
|
||||
return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
|
||||
return am.CreateRouteFunc(accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||
}
|
||||
@@ -623,6 +627,14 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
|
||||
if am.GetPeerAppliedPostureChecksFunc != nil {
|
||||
return am.GetPeerAppliedPostureChecksFunc(peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented")
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
@@ -736,6 +748,14 @@ func (am *MockAccountManager) GroupValidation(accountId string, groups []string)
|
||||
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(peerPubKey, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
|
||||
// FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface
|
||||
func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
if am.FindExistingPostureCheckFunc != nil {
|
||||
|
||||
@@ -3,9 +3,10 @@ package mock_server
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
type ManagementServiceServerMock struct {
|
||||
@@ -17,6 +18,7 @@ type ManagementServiceServerMock struct {
|
||||
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
|
||||
GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
|
||||
GetPKCEAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
|
||||
SyncMetaFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error)
|
||||
}
|
||||
|
||||
func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
@@ -60,3 +62,10 @@ func (m ManagementServiceServerMock) GetPKCEAuthorizationFlow(ctx context.Contex
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
|
||||
}
|
||||
|
||||
func (m ManagementServiceServerMock) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
if m.SyncMetaFunc != nil {
|
||||
return m.SyncMetaFunc(ctx, req)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type PeerSync struct {
|
||||
// WireGuardPubKey is a peers WireGuard public key
|
||||
WireGuardPubKey string
|
||||
// Meta is the system information passed by peer, must be always present
|
||||
Meta nbpeer.PeerSystemMeta
|
||||
// UpdateAccountPeers indicate updating account peers,
|
||||
// which occurs when the peer's metadata is updated
|
||||
UpdateAccountPeers bool
|
||||
}
|
||||
|
||||
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
|
||||
@@ -528,6 +533,18 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||
if updated {
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if sync.UpdateAccountPeers {
|
||||
am.updateAccountPeers(account)
|
||||
}
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -900,7 +917,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||
continue
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -79,6 +80,13 @@ type Environment struct {
|
||||
Platform string
|
||||
}
|
||||
|
||||
// File is a file on the system.
|
||||
type File struct {
|
||||
Path string
|
||||
Exist bool
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// PeerSystemMeta is a metadata of a Peer machine system
|
||||
type PeerSystemMeta struct { //nolint:revive
|
||||
Hostname string
|
||||
@@ -96,24 +104,22 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment `gorm:"serializer:json"`
|
||||
Files []File `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
if len(p.NetworkAddresses) != len(other.NetworkAddresses) {
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
if !equalNetworkAddresses {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, addr := range p.NetworkAddresses {
|
||||
var found bool
|
||||
for _, oAddr := range other.NetworkAddresses {
|
||||
if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
if !equalFiles {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
@@ -133,6 +139,26 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
p.Environment.Platform == other.Environment.Platform
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEmpty() bool {
|
||||
return p.Hostname == "" &&
|
||||
p.GoOS == "" &&
|
||||
p.Kernel == "" &&
|
||||
p.Core == "" &&
|
||||
p.Platform == "" &&
|
||||
p.OS == "" &&
|
||||
p.OSVersion == "" &&
|
||||
p.WtVersion == "" &&
|
||||
p.UIVersion == "" &&
|
||||
p.KernelVersion == "" &&
|
||||
len(p.NetworkAddresses) == 0 &&
|
||||
p.SystemSerialNumber == "" &&
|
||||
p.SystemProductName == "" &&
|
||||
p.SystemManufacturer == "" &&
|
||||
p.Environment.Cloud == "" &&
|
||||
p.Environment.Platform == "" &&
|
||||
len(p.Files) == 0
|
||||
}
|
||||
|
||||
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
|
||||
func (p *Peer) AddedWithSSOLogin() bool {
|
||||
return p.UserID != ""
|
||||
@@ -168,6 +194,10 @@ func (p *Peer) Copy() *Peer {
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
|
||||
if meta.isEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/rs/xid"
|
||||
@@ -17,15 +18,21 @@ const (
|
||||
OSVersionCheckName = "OSVersionCheck"
|
||||
GeoLocationCheckName = "GeoLocationCheck"
|
||||
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
|
||||
ProcessCheckName = "ProcessCheck"
|
||||
|
||||
CheckActionAllow string = "allow"
|
||||
CheckActionDeny string = "deny"
|
||||
)
|
||||
|
||||
var (
|
||||
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
|
||||
)
|
||||
|
||||
// Check represents an interface for performing a check on a peer.
|
||||
type Check interface {
|
||||
Check(peer nbpeer.Peer) (bool, error)
|
||||
Name() string
|
||||
Check(peer nbpeer.Peer) (bool, error)
|
||||
Validate() error
|
||||
}
|
||||
|
||||
type Checks struct {
|
||||
@@ -51,6 +58,7 @@ type ChecksDefinition struct {
|
||||
OSVersionCheck *OSVersionCheck `json:",omitempty"`
|
||||
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
|
||||
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
|
||||
ProcessCheck *ProcessCheck `json:",omitempty"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of a checks definition.
|
||||
@@ -96,6 +104,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
|
||||
}
|
||||
copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
|
||||
}
|
||||
if cd.ProcessCheck != nil {
|
||||
processCheck := cd.ProcessCheck
|
||||
cdCopy.ProcessCheck = &ProcessCheck{
|
||||
Processes: make([]Process, len(processCheck.Processes)),
|
||||
}
|
||||
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
|
||||
}
|
||||
return cdCopy
|
||||
}
|
||||
|
||||
@@ -136,6 +151,9 @@ func (pc *Checks) GetChecks() []Check {
|
||||
if pc.Checks.PeerNetworkRangeCheck != nil {
|
||||
checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
|
||||
}
|
||||
if pc.Checks.ProcessCheck != nil {
|
||||
checks = append(checks, pc.Checks.ProcessCheck)
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -191,6 +209,10 @@ func buildPostureCheck(postureChecksID string, name string, description string,
|
||||
}
|
||||
}
|
||||
|
||||
if processCheck := checks.ProcessCheck; processCheck != nil {
|
||||
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
|
||||
}
|
||||
|
||||
return &postureChecks, nil
|
||||
}
|
||||
|
||||
@@ -221,6 +243,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
|
||||
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(pc.Checks.PeerNetworkRangeCheck)
|
||||
}
|
||||
|
||||
if pc.Checks.ProcessCheck != nil {
|
||||
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
|
||||
}
|
||||
|
||||
return &api.PostureCheck{
|
||||
Id: pc.ID,
|
||||
Name: pc.Name,
|
||||
@@ -229,44 +255,20 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks the validity of a posture checks.
|
||||
func (pc *Checks) Validate() error {
|
||||
if check := pc.Checks.NBVersionCheck; check != nil {
|
||||
if !isVersionValid(check.MinVersion) {
|
||||
return fmt.Errorf("%s version: %s is not valid", check.Name(), check.MinVersion)
|
||||
}
|
||||
if pc.Name == "" {
|
||||
return errors.New("posture checks name shouldn't be empty")
|
||||
}
|
||||
|
||||
if osCheck := pc.Checks.OSVersionCheck; osCheck != nil {
|
||||
if osCheck.Android != nil {
|
||||
if !isVersionValid(osCheck.Android.MinVersion) {
|
||||
return fmt.Errorf("%s android version: %s is not valid", osCheck.Name(), osCheck.Android.MinVersion)
|
||||
}
|
||||
}
|
||||
checks := pc.GetChecks()
|
||||
if len(checks) == 0 {
|
||||
return errors.New("posture checks shouldn't be empty")
|
||||
}
|
||||
|
||||
if osCheck.Ios != nil {
|
||||
if !isVersionValid(osCheck.Ios.MinVersion) {
|
||||
return fmt.Errorf("%s ios version: %s is not valid", osCheck.Name(), osCheck.Ios.MinVersion)
|
||||
}
|
||||
}
|
||||
|
||||
if osCheck.Darwin != nil {
|
||||
if !isVersionValid(osCheck.Darwin.MinVersion) {
|
||||
return fmt.Errorf("%s darwin version: %s is not valid", osCheck.Name(), osCheck.Darwin.MinVersion)
|
||||
}
|
||||
}
|
||||
|
||||
if osCheck.Linux != nil {
|
||||
if !isVersionValid(osCheck.Linux.MinKernelVersion) {
|
||||
return fmt.Errorf("%s linux kernel version: %s is not valid", osCheck.Name(),
|
||||
osCheck.Linux.MinKernelVersion)
|
||||
}
|
||||
}
|
||||
|
||||
if osCheck.Windows != nil {
|
||||
if !isVersionValid(osCheck.Windows.MinKernelVersion) {
|
||||
return fmt.Errorf("%s windows kernel version: %s is not valid", osCheck.Name(),
|
||||
osCheck.Windows.MinKernelVersion)
|
||||
}
|
||||
for _, check := range checks {
|
||||
if err := check.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,3 +354,40 @@ func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*PeerNetworkRang
|
||||
Action: string(check.Action),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toProcessCheckResponse(check *ProcessCheck) *api.ProcessCheck {
|
||||
processes := make([]api.Process, 0, len(check.Processes))
|
||||
for i := range check.Processes {
|
||||
processes = append(processes, api.Process{
|
||||
LinuxPath: &check.Processes[i].LinuxPath,
|
||||
MacPath: &check.Processes[i].MacPath,
|
||||
WindowsPath: &check.Processes[i].WindowsPath,
|
||||
})
|
||||
}
|
||||
|
||||
return &api.ProcessCheck{
|
||||
Processes: processes,
|
||||
}
|
||||
}
|
||||
|
||||
func toProcessCheck(check *api.ProcessCheck) *ProcessCheck {
|
||||
processes := make([]Process, 0, len(check.Processes))
|
||||
for _, process := range check.Processes {
|
||||
var p Process
|
||||
if process.LinuxPath != nil {
|
||||
p.LinuxPath = *process.LinuxPath
|
||||
}
|
||||
if process.MacPath != nil {
|
||||
p.MacPath = *process.MacPath
|
||||
}
|
||||
if process.WindowsPath != nil {
|
||||
p.WindowsPath = *process.WindowsPath
|
||||
}
|
||||
|
||||
processes = append(processes, p)
|
||||
}
|
||||
|
||||
return &ProcessCheck{
|
||||
Processes: processes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,9 +150,23 @@ func TestChecks_Validate(t *testing.T) {
|
||||
checks Checks
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Empty name",
|
||||
checks: Checks{},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty checks",
|
||||
checks: Checks{
|
||||
Name: "Default",
|
||||
Checks: ChecksDefinition{},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid checks version",
|
||||
checks: Checks{
|
||||
Name: "default",
|
||||
Checks: ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{
|
||||
MinVersion: "0.25.0",
|
||||
@@ -261,6 +275,14 @@ func TestChecks_Copy(t *testing.T) {
|
||||
},
|
||||
Action: CheckActionDeny,
|
||||
},
|
||||
ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{
|
||||
{
|
||||
MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird",
|
||||
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
checkCopy := check.Copy()
|
||||
|
||||
@@ -2,6 +2,7 @@ package posture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -60,3 +61,28 @@ func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (g *GeoLocationCheck) Name() string {
|
||||
return GeoLocationCheckName
|
||||
}
|
||||
|
||||
func (g *GeoLocationCheck) Validate() error {
|
||||
if g.Action == "" {
|
||||
return fmt.Errorf("%s action shouldn't be empty", g.Name())
|
||||
}
|
||||
|
||||
allowedActions := []string{CheckActionAllow, CheckActionDeny}
|
||||
if !slices.Contains(allowedActions, g.Action) {
|
||||
return fmt.Errorf("%s action is not valid", g.Name())
|
||||
}
|
||||
|
||||
if len(g.Locations) == 0 {
|
||||
return fmt.Errorf("%s locations shouldn't be empty", g.Name())
|
||||
}
|
||||
|
||||
for _, loc := range g.Locations {
|
||||
if loc.CountryCode == "" {
|
||||
return fmt.Errorf("%s country code shouldn't be empty", g.Name())
|
||||
}
|
||||
if !countryCodeRegex.MatchString(loc.CountryCode) {
|
||||
return fmt.Errorf("%s country code must be 2 letters (ISO 3166-1 alpha-2 format)", g.Name())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -236,3 +236,81 @@ func TestGeoLocationCheck_Check(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoLocationCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check GeoLocationCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid location list",
|
||||
check: GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{
|
||||
{
|
||||
CountryCode: "DE",
|
||||
CityName: "Berlin",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty location list",
|
||||
check: GeoLocationCheck{
|
||||
Action: CheckActionDeny,
|
||||
Locations: []Location{},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty country name",
|
||||
check: GeoLocationCheck{
|
||||
Action: CheckActionDeny,
|
||||
Locations: []Location{
|
||||
{
|
||||
CityName: "Los Angeles",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid check action",
|
||||
check: GeoLocationCheck{
|
||||
Action: "unknownAction",
|
||||
Locations: []Location{
|
||||
{
|
||||
CountryCode: "DE",
|
||||
CityName: "Berlin",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid country code",
|
||||
check: GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{
|
||||
{
|
||||
CountryCode: "USA",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -37,3 +39,13 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (n *NBVersionCheck) Name() string {
|
||||
return NBVersionCheckName
|
||||
}
|
||||
|
||||
func (n *NBVersionCheck) Validate() error {
|
||||
if n.MinVersion == "" {
|
||||
return fmt.Errorf("%s minimum version shouldn't be empty", n.Name())
|
||||
}
|
||||
if !isVersionValid(n.MinVersion) {
|
||||
return fmt.Errorf("%s version: %s is not valid", n.Name(), n.MinVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -108,3 +108,33 @@ func TestNBVersionCheck_Check(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNBVersionCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check NBVersionCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid NBVersionCheck",
|
||||
check: NBVersionCheck{MinVersion: "1.0"},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid NBVersionCheck",
|
||||
check: NBVersionCheck{},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"slices"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
type PeerNetworkRangeCheck struct {
|
||||
@@ -52,3 +53,19 @@ func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (p *PeerNetworkRangeCheck) Name() string {
|
||||
return PeerNetworkRangeCheckName
|
||||
}
|
||||
|
||||
func (p *PeerNetworkRangeCheck) Validate() error {
|
||||
if p.Action == "" {
|
||||
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
|
||||
}
|
||||
|
||||
allowedActions := []string{CheckActionAllow, CheckActionDeny}
|
||||
if !slices.Contains(allowedActions, p.Action) {
|
||||
return fmt.Errorf("%s action is not valid", p.Name())
|
||||
}
|
||||
|
||||
if len(p.Ranges) == 0 {
|
||||
return fmt.Errorf("%s network ranges shouldn't be empty", p.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -147,3 +147,52 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check PeerNetworkRangeCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid network range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty network range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid check action",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: "unknownAction",
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
type MinVersionCheck struct {
|
||||
@@ -48,6 +50,35 @@ func (c *OSVersionCheck) Name() string {
|
||||
return OSVersionCheckName
|
||||
}
|
||||
|
||||
func (c *OSVersionCheck) Validate() error {
|
||||
if c.Android == nil && c.Darwin == nil && c.Ios == nil && c.Linux == nil && c.Windows == nil {
|
||||
return fmt.Errorf("%s at least one OS version check is required", c.Name())
|
||||
}
|
||||
|
||||
if c.Android != nil && !isVersionValid(c.Android.MinVersion) {
|
||||
return fmt.Errorf("%s android version: %s is not valid", c.Name(), c.Android.MinVersion)
|
||||
}
|
||||
|
||||
if c.Ios != nil && !isVersionValid(c.Ios.MinVersion) {
|
||||
return fmt.Errorf("%s ios version: %s is not valid", c.Name(), c.Ios.MinVersion)
|
||||
}
|
||||
|
||||
if c.Darwin != nil && !isVersionValid(c.Darwin.MinVersion) {
|
||||
return fmt.Errorf("%s darwin version: %s is not valid", c.Name(), c.Darwin.MinVersion)
|
||||
}
|
||||
|
||||
if c.Linux != nil && !isVersionValid(c.Linux.MinKernelVersion) {
|
||||
return fmt.Errorf("%s linux kernel version: %s is not valid", c.Name(),
|
||||
c.Linux.MinKernelVersion)
|
||||
}
|
||||
|
||||
if c.Windows != nil && !isVersionValid(c.Windows.MinKernelVersion) {
|
||||
return fmt.Errorf("%s windows kernel version: %s is not valid", c.Name(),
|
||||
c.Windows.MinKernelVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
|
||||
if check == nil {
|
||||
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
||||
|
||||
@@ -150,3 +150,79 @@ func TestOSVersionCheck_Check(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOSVersionCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check OSVersionCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid linux kernel version",
|
||||
check: OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid linux and darwin version",
|
||||
check: OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
|
||||
Darwin: &MinVersionCheck{MinVersion: "14.2"},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty check",
|
||||
check: OSVersionCheck{},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty linux kernel version",
|
||||
check: OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty linux kernel version with correct darwin version",
|
||||
check: OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{},
|
||||
Darwin: &MinVersionCheck{MinVersion: "14.2"},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid windows kernel version",
|
||||
check: OSVersionCheck{
|
||||
Windows: &MinKernelVersionCheck{MinKernelVersion: "10.0"},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid ios minimum version",
|
||||
check: OSVersionCheck{
|
||||
Ios: &MinVersionCheck{MinVersion: "13.0"},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty window version with valid ios minimum version",
|
||||
check: OSVersionCheck{
|
||||
Windows: &MinKernelVersionCheck{},
|
||||
Ios: &MinVersionCheck{MinVersion: "13.0"},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
79
management/server/posture/process.go
Normal file
79
management/server/posture/process.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
LinuxPath string
|
||||
MacPath string
|
||||
WindowsPath string
|
||||
}
|
||||
|
||||
type ProcessCheck struct {
|
||||
Processes []Process
|
||||
}
|
||||
|
||||
var _ Check = (*ProcessCheck)(nil)
|
||||
|
||||
func (p *ProcessCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
peerActiveProcesses := extractPeerActiveProcesses(peer.Meta.Files)
|
||||
|
||||
var pathSelector func(Process) string
|
||||
switch peer.Meta.GoOS {
|
||||
case "linux":
|
||||
pathSelector = func(process Process) string { return process.LinuxPath }
|
||||
case "darwin":
|
||||
pathSelector = func(process Process) string { return process.MacPath }
|
||||
case "windows":
|
||||
pathSelector = func(process Process) string { return process.WindowsPath }
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported peer's operating system: %s", peer.Meta.GoOS)
|
||||
}
|
||||
|
||||
return p.areAllProcessesRunning(peerActiveProcesses, pathSelector), nil
|
||||
}
|
||||
|
||||
func (p *ProcessCheck) Name() string {
|
||||
return ProcessCheckName
|
||||
}
|
||||
|
||||
func (p *ProcessCheck) Validate() error {
|
||||
if len(p.Processes) == 0 {
|
||||
return fmt.Errorf("%s processes shouldn't be empty", p.Name())
|
||||
}
|
||||
|
||||
for _, process := range p.Processes {
|
||||
if process.LinuxPath == "" && process.MacPath == "" && process.WindowsPath == "" {
|
||||
return fmt.Errorf("%s path shouldn't be empty", p.Name())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// areAllProcessesRunning checks if all processes specified in ProcessCheck are running.
|
||||
// It uses the provided pathSelector to get the appropriate process path for the peer's OS.
|
||||
// It returns true if all processes are running, otherwise false.
|
||||
func (p *ProcessCheck) areAllProcessesRunning(activeProcesses []string, pathSelector func(Process) string) bool {
|
||||
for _, process := range p.Processes {
|
||||
path := pathSelector(process)
|
||||
if path == "" || !slices.Contains(activeProcesses, path) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// extractPeerActiveProcesses extracts the paths of running processes from the peer meta.
|
||||
func extractPeerActiveProcesses(files []nbpeer.File) []string {
|
||||
activeProcesses := make([]string, 0, len(files))
|
||||
for _, file := range files {
|
||||
if file.ProcessIsRunning {
|
||||
activeProcesses = append(activeProcesses, file.Path)
|
||||
}
|
||||
}
|
||||
return activeProcesses
|
||||
}
|
||||
318
management/server/posture/process_test.go
Normal file
318
management/server/posture/process_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestProcessCheck_Check(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input peer.Peer
|
||||
check ProcessCheck
|
||||
wantErr bool
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "darwin with matching running processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
Files: []peer.File{
|
||||
{Path: "/Applications/process1.app", ProcessIsRunning: true},
|
||||
{Path: "/Applications/process2.app", ProcessIsRunning: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{MacPath: "/Applications/process1.app"},
|
||||
{MacPath: "/Applications/process2.app"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "darwin with windows process paths",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
Files: []peer.File{
|
||||
{Path: "/Applications/process1.app", ProcessIsRunning: true},
|
||||
{Path: "/Applications/process2.app", ProcessIsRunning: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{WindowsPath: "C:\\Program Files\\process2.exe"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "linux with matching running processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
Files: []peer.File{
|
||||
{Path: "/usr/bin/process1", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/process2", ProcessIsRunning: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{LinuxPath: "/usr/bin/process1"},
|
||||
{LinuxPath: "/usr/bin/process2"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "linux with matching no running processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
Files: []peer.File{
|
||||
{Path: "/usr/bin/process1", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/process2", ProcessIsRunning: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{LinuxPath: "/usr/bin/process1"},
|
||||
{LinuxPath: "/usr/bin/process2"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "linux with windows process paths",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
Files: []peer.File{
|
||||
{Path: "/usr/bin/process1", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/process2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{WindowsPath: "C:\\Program Files\\process2.exe"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "linux with non-matching processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
Files: []peer.File{
|
||||
{Path: "/usr/bin/process3"},
|
||||
{Path: "/usr/bin/process4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{LinuxPath: "/usr/bin/process1"},
|
||||
{LinuxPath: "/usr/bin/process2"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "windows with matching running processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
Files: []peer.File{
|
||||
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
|
||||
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "windows with darwin process paths",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
Files: []peer.File{
|
||||
{Path: "C:\\Program Files\\process1.exe"},
|
||||
{Path: "C:\\Program Files\\process1.exe"},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{MacPath: "/Applications/process1.app"},
|
||||
{LinuxPath: "/Applications/process2.app"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "windows with non-matching processes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
Files: []peer.File{
|
||||
{Path: "C:\\Program Files\\process3.exe"},
|
||||
{Path: "C:\\Program Files\\process4.exe"},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{WindowsPath: "C:\\Program Files\\process2.exe"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported ios operating system",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "ios",
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{MacPath: "/Applications/process2.app"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported android operating system",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "android",
|
||||
},
|
||||
},
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{WindowsPath: "C:\\Program Files\\process1.exe"},
|
||||
{MacPath: "/Applications/process2.app"},
|
||||
{LinuxPath: "/usr/bin/process2"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid, err := tt.check.Check(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.isValid, isValid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check ProcessCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid linux, mac and windows processes",
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/bin/netbird",
|
||||
MacPath: "/usr/local/bin/netbird",
|
||||
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid linux process",
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid mac process",
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{
|
||||
MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid windows process",
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{
|
||||
{
|
||||
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid empty processes",
|
||||
check: ProcessCheck{
|
||||
Processes: []Process{},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
@@ -21,7 +29,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
@@ -48,11 +56,11 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.BadRequest, err.Error())
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
@@ -95,7 +103,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
||||
@@ -127,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks")
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return account.PostureChecks, nil
|
||||
@@ -176,3 +184,74 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
|
||||
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer.
|
||||
func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting peer %s: %v", peerKey, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(peerKey)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer is not registered")
|
||||
}
|
||||
if peer == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
peerPostureChecks := am.collectPeerPostureChecks(account, peer)
|
||||
|
||||
postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks))
|
||||
for _, check := range peerPostureChecks {
|
||||
postureChecksList = append(postureChecksList, check)
|
||||
}
|
||||
|
||||
return postureChecksList, nil
|
||||
}
|
||||
|
||||
// collectPeerPostureChecks collects the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer *nbpeer.Peer) map[string]posture.Checks {
|
||||
peerPostureChecks := make(map[string]posture.Checks)
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
||||
}
|
||||
}
|
||||
|
||||
return peerPostureChecks
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, ok := account.Groups[sourceGroup]
|
||||
if ok && slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if postureCheck.ID == sourcePostureCheckID {
|
||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -39,10 +40,10 @@ func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, us
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
}
|
||||
|
||||
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -114,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
||||
}
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||
func (am *DefaultAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -123,6 +124,18 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(domains) == 0 && !prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
|
||||
}
|
||||
|
||||
if len(domains) > 0 {
|
||||
prefix = getPlaceholderIP()
|
||||
}
|
||||
|
||||
if peerID != "" && len(peerGroupIDs) != 0 {
|
||||
return nil, status.Errorf(
|
||||
status.InvalidArgument,
|
||||
@@ -133,11 +146,6 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
var newRoute route.Route
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
prefixType, newPrefix, err := route.ParseNetwork(network)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
|
||||
}
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
@@ -145,7 +153,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixExistsForPeers(account, peerID, newRoute.ID, peerGroupIDs, newPrefix)
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,14 +173,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
newRoute.Network = newPrefix
|
||||
newRoute.NetworkType = prefixType
|
||||
newRoute.Network = prefix
|
||||
newRoute.Domains = domains
|
||||
newRoute.NetworkType = networkType
|
||||
newRoute.Description = description
|
||||
newRoute.NetID = netID
|
||||
newRoute.Masquerade = masquerade
|
||||
newRoute.Metric = metric
|
||||
newRoute.Enabled = enabled
|
||||
newRoute.Groups = groups
|
||||
newRoute.KeepRoute = keepRoute
|
||||
|
||||
if account.Routes == nil {
|
||||
account.Routes = make(map[route.ID]*route.Route)
|
||||
@@ -201,10 +211,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
|
||||
if !routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
|
||||
}
|
||||
|
||||
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
@@ -218,6 +224,18 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) == 0 && !routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "invalid Prefix")
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 {
|
||||
routeToSave.Network = getPlaceholderIP()
|
||||
}
|
||||
|
||||
if routeToSave.Peer != "" && len(routeToSave.PeerGroups) != 0 {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
@@ -229,7 +247,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixExistsForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network)
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -313,10 +331,12 @@ func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,3 +347,9 @@ func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
// getPlaceholderIP returns a placeholder IP address for the route if domains are used
|
||||
func getPlaceholderIP() netip.Prefix {
|
||||
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -33,13 +34,18 @@ const (
|
||||
routeGroupHA2 = "routeGroupHA2"
|
||||
routeInvalidGroup1 = "routeInvalidGroup1"
|
||||
userID = "testingUser"
|
||||
existingNetwork = "10.10.10.0/24"
|
||||
existingRouteID = "random-id"
|
||||
)
|
||||
|
||||
var existingNetwork = netip.MustParsePrefix("10.10.10.0/24")
|
||||
var existingDomains = domain.List{"example.com"}
|
||||
|
||||
func TestCreateRoute(t *testing.T) {
|
||||
type input struct {
|
||||
network string
|
||||
network netip.Prefix
|
||||
domains domain.List
|
||||
keepRoute bool
|
||||
networkType route.NetworkType
|
||||
netID route.NetID
|
||||
peerKey string
|
||||
peerGroupIDs []string
|
||||
@@ -59,9 +65,10 @@ func TestCreateRoute(t *testing.T) {
|
||||
expectedRoute *route.Route
|
||||
}{
|
||||
{
|
||||
name: "Happy Path",
|
||||
name: "Happy Path Network",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
description: "super",
|
||||
@@ -84,10 +91,41 @@ func TestCreateRoute(t *testing.T) {
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Happy Path Domains",
|
||||
inputArgs: input{
|
||||
domains: domain.List{"domain1", "domain2"},
|
||||
keepRoute: true,
|
||||
networkType: route.DomainNetwork,
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
description: "super",
|
||||
masquerade: false,
|
||||
metric: 9999,
|
||||
enabled: true,
|
||||
groups: []string{routeGroup1},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||
Domains: domain.List{"domain1", "domain2"},
|
||||
NetworkType: route.DomainNetwork,
|
||||
NetID: "happy",
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
KeepRoute: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Happy Path Peer Groups",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
|
||||
description: "super",
|
||||
@@ -111,9 +149,10 @@ func TestCreateRoute(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both peer and peer_groups Provided Should Fail",
|
||||
name: "Both network and domains provided should fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
domains: domain.List{"domain1", "domain2"},
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
peerGroupIDs: []string{routeGroupHA1},
|
||||
@@ -127,16 +166,18 @@ func TestCreateRoute(t *testing.T) {
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Bad Prefix Should Fail",
|
||||
name: "Both peer and peer_groups Provided Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/34",
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
description: "super",
|
||||
masquerade: false,
|
||||
metric: 9999,
|
||||
enabled: true,
|
||||
groups: []string{routeGroup1},
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
peerGroupIDs: []string{routeGroupHA1},
|
||||
description: "super",
|
||||
masquerade: false,
|
||||
metric: 9999,
|
||||
enabled: true,
|
||||
groups: []string{routeGroup1},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
@@ -144,7 +185,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Bad Peer Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerKey: "notExistingPeer",
|
||||
description: "super",
|
||||
@@ -157,9 +199,10 @@ func TestCreateRoute(t *testing.T) {
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Bad Peer already has this route",
|
||||
name: "Bad Peer already has this network route",
|
||||
inputArgs: input{
|
||||
network: existingNetwork,
|
||||
networkType: route.IPv4Network,
|
||||
netID: "bad",
|
||||
peerKey: peer5ID,
|
||||
description: "super",
|
||||
@@ -173,9 +216,44 @@ func TestCreateRoute(t *testing.T) {
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Bad Peers Group already has this route",
|
||||
name: "Bad Peer already has this domains route",
|
||||
inputArgs: input{
|
||||
domains: existingDomains,
|
||||
networkType: route.DomainNetwork,
|
||||
netID: "bad",
|
||||
peerKey: peer5ID,
|
||||
description: "super",
|
||||
masquerade: false,
|
||||
metric: 9999,
|
||||
enabled: true,
|
||||
groups: []string{routeGroup1},
|
||||
},
|
||||
createInitRoute: true,
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Bad Peers Group already has this network route",
|
||||
inputArgs: input{
|
||||
network: existingNetwork,
|
||||
networkType: route.IPv4Network,
|
||||
netID: "bad",
|
||||
peerGroupIDs: []string{routeGroup1, routeGroup3},
|
||||
description: "super",
|
||||
masquerade: false,
|
||||
metric: 9999,
|
||||
enabled: true,
|
||||
groups: []string{routeGroup1},
|
||||
},
|
||||
createInitRoute: true,
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Bad Peers Group already has this domains route",
|
||||
inputArgs: input{
|
||||
domains: existingDomains,
|
||||
networkType: route.DomainNetwork,
|
||||
netID: "bad",
|
||||
peerGroupIDs: []string{routeGroup1, routeGroup3},
|
||||
description: "super",
|
||||
@@ -191,7 +269,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Empty Peer Should Create",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerKey: "",
|
||||
description: "super",
|
||||
@@ -217,7 +296,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Large Metric Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
peerKey: peer1ID,
|
||||
netID: "happy",
|
||||
description: "super",
|
||||
@@ -232,7 +312,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Small Metric Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "happy",
|
||||
peerKey: peer1ID,
|
||||
description: "super",
|
||||
@@ -247,7 +328,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Large NetID Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
peerKey: peer1ID,
|
||||
netID: "12345678901234567890qwertyuiopqwertyuiop1",
|
||||
description: "super",
|
||||
@@ -262,7 +344,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Small NetID Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "",
|
||||
peerKey: peer1ID,
|
||||
description: "",
|
||||
@@ -277,7 +360,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Empty Group List Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "NewId",
|
||||
peerKey: peer1ID,
|
||||
description: "",
|
||||
@@ -292,7 +376,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Empty Group ID string Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "NewId",
|
||||
peerKey: peer1ID,
|
||||
description: "",
|
||||
@@ -307,7 +392,8 @@ func TestCreateRoute(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Group Should Fail",
|
||||
inputArgs: input{
|
||||
network: "192.168.0.0/16",
|
||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
networkType: route.IPv4Network,
|
||||
netID: "NewId",
|
||||
peerKey: peer1ID,
|
||||
description: "",
|
||||
@@ -334,29 +420,14 @@ func TestCreateRoute(t *testing.T) {
|
||||
|
||||
if testCase.createInitRoute {
|
||||
groupAll, errInit := account.GetGroupAll()
|
||||
if errInit != nil {
|
||||
t.Errorf("failed to get group all: %s", errInit)
|
||||
}
|
||||
_, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4},
|
||||
"", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID)
|
||||
if errInit != nil {
|
||||
t.Errorf("failed to create init route: %s", errInit)
|
||||
}
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
}
|
||||
|
||||
outRoute, err := am.CreateRoute(
|
||||
account.Id,
|
||||
testCase.inputArgs.network,
|
||||
testCase.inputArgs.peerKey,
|
||||
testCase.inputArgs.peerGroupIDs,
|
||||
testCase.inputArgs.description,
|
||||
testCase.inputArgs.netID,
|
||||
testCase.inputArgs.masquerade,
|
||||
testCase.inputArgs.metric,
|
||||
testCase.inputArgs.groups,
|
||||
testCase.inputArgs.enabled,
|
||||
userID,
|
||||
)
|
||||
outRoute, err := am.CreateRoute(account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@@ -379,8 +450,13 @@ func TestSaveRoute(t *testing.T) {
|
||||
validUsedPeer := peer5ID
|
||||
invalidPeer := "nonExisting"
|
||||
validPrefix := netip.MustParsePrefix("192.168.0.0/24")
|
||||
placeholderPrefix := netip.MustParsePrefix("192.0.2.0/32")
|
||||
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
|
||||
validMetric := 1000
|
||||
trueKeepRoute := true
|
||||
falseKeepRoute := false
|
||||
ipv4networkType := route.IPv4Network
|
||||
domainNetworkType := route.DomainNetwork
|
||||
invalidMetric := 99999
|
||||
validNetID := route.NetID("12345678901234567890qw")
|
||||
invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
|
||||
@@ -395,6 +471,9 @@ func TestSaveRoute(t *testing.T) {
|
||||
newPeerGroups []string
|
||||
newMetric *int
|
||||
newPrefix *netip.Prefix
|
||||
newDomains domain.List
|
||||
newNetworkType *route.NetworkType
|
||||
newKeepRoute *bool
|
||||
newGroups []string
|
||||
skipCopying bool
|
||||
shouldCreate bool
|
||||
@@ -402,7 +481,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
expectedRoute *route.Route
|
||||
}{
|
||||
{
|
||||
name: "Happy Path",
|
||||
name: "Happy Path Network",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
@@ -434,6 +513,45 @@ func TestSaveRoute(t *testing.T) {
|
||||
Groups: []string{routeGroup2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Happy Path Domains",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.Prefix{},
|
||||
Domains: domain.List{"example.com"},
|
||||
KeepRoute: false,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.DomainNetwork,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
newPeer: &validPeer,
|
||||
newMetric: &validMetric,
|
||||
newPrefix: &netip.Prefix{},
|
||||
newDomains: domain.List{"example.com", "example2.com"},
|
||||
newKeepRoute: &trueKeepRoute,
|
||||
newGroups: []string{routeGroup1},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: placeholderPrefix,
|
||||
Domains: domain.List{"example.com", "example2.com"},
|
||||
KeepRoute: true,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.DomainNetwork,
|
||||
Peer: validPeer,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: validMetric,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Happy Path Peer Groups",
|
||||
existingRoute: &route.Route{
|
||||
@@ -466,6 +584,23 @@ func TestSaveRoute(t *testing.T) {
|
||||
Groups: []string{routeGroup2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both network and domains provided should fail",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
newPrefix: &validPrefix,
|
||||
newDomains: domain.List{"example.com"},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Both peer and peers_roup Provided Should Fail",
|
||||
existingRoute: &route.Route{
|
||||
@@ -623,7 +758,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
name: "Allow to modify existing route with new peer",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix(existingNetwork),
|
||||
Network: existingNetwork,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
@@ -638,7 +773,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix(existingNetwork),
|
||||
Network: existingNetwork,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: validPeer,
|
||||
@@ -654,7 +789,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
name: "Do not allow to modify existing route with a peer from another route",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix(existingNetwork),
|
||||
Network: existingNetwork,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
@@ -672,7 +807,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
name: "Do not allow to modify existing route with a peers group from another route",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix(existingNetwork),
|
||||
Network: existingNetwork,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{routeGroup3},
|
||||
@@ -686,6 +821,80 @@ func TestSaveRoute(t *testing.T) {
|
||||
newPeerGroups: []string{routeGroup4},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Allow switching from network route to domains route",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: validPrefix,
|
||||
Domains: nil,
|
||||
KeepRoute: false,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
newPrefix: &netip.Prefix{},
|
||||
newDomains: domain.List{"example.com"},
|
||||
newNetworkType: &domainNetworkType,
|
||||
newKeepRoute: &trueKeepRoute,
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: placeholderPrefix,
|
||||
NetworkType: route.DomainNetwork,
|
||||
Domains: domain.List{"example.com"},
|
||||
KeepRoute: true,
|
||||
NetID: validNetID,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Allow switching from domains route to network route",
|
||||
existingRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: placeholderPrefix,
|
||||
Domains: domain.List{"example.com"},
|
||||
KeepRoute: true,
|
||||
NetID: validNetID,
|
||||
NetworkType: route.DomainNetwork,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
newPrefix: &validPrefix,
|
||||
newDomains: nil,
|
||||
newKeepRoute: &falseKeepRoute,
|
||||
newNetworkType: &ipv4networkType,
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: validPrefix,
|
||||
NetworkType: route.IPv4Network,
|
||||
KeepRoute: false,
|
||||
Domains: nil,
|
||||
NetID: validNetID,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
@@ -702,7 +911,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
if testCase.createInitRoute {
|
||||
account.Routes["initRoute"] = &route.Route{
|
||||
ID: "initRoute",
|
||||
Network: netip.MustParsePrefix(existingNetwork),
|
||||
Network: existingNetwork,
|
||||
NetID: existingRouteID,
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{routeGroup4},
|
||||
@@ -739,6 +948,16 @@ func TestSaveRoute(t *testing.T) {
|
||||
routeToSave.Network = *testCase.newPrefix
|
||||
}
|
||||
|
||||
routeToSave.Domains = testCase.newDomains
|
||||
|
||||
if testCase.newNetworkType != nil {
|
||||
routeToSave.NetworkType = *testCase.newNetworkType
|
||||
}
|
||||
|
||||
if testCase.newKeepRoute != nil {
|
||||
routeToSave.KeepRoute = *testCase.newKeepRoute
|
||||
}
|
||||
|
||||
if testCase.newGroups != nil {
|
||||
routeToSave.Groups = testCase.newGroups
|
||||
}
|
||||
@@ -771,6 +990,8 @@ func TestDeleteRoute(t *testing.T) {
|
||||
testingRoute := &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
Domains: domain.List{"domain1", "domain2"},
|
||||
KeepRoute: true,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1Key,
|
||||
Description: "super",
|
||||
@@ -839,9 +1060,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
newRoute, err := am.CreateRoute(
|
||||
account.Id, baseRoute.Network.String(), baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description,
|
||||
baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID)
|
||||
newRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newRoute.Enabled, true)
|
||||
|
||||
@@ -932,9 +1151,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, []string{},
|
||||
baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false,
|
||||
userID)
|
||||
createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
noDisabledRoutes, err := am.GetNetworkMap(peer1ID)
|
||||
|
||||
Reference in New Issue
Block a user