mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-28 21:26:40 +00:00
Compare commits
14 Commits
ci-win-tes
...
feature/di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
279e96e6b1 | ||
|
|
245481f33b | ||
|
|
b352ab84c0 | ||
|
|
3ce5d6a4f8 | ||
|
|
4c2eb2af73 | ||
|
|
daf1449174 | ||
|
|
1ff7abe909 | ||
|
|
067c77e49e | ||
|
|
291e640b28 | ||
|
|
efb954b7d6 | ||
|
|
cac9326d3d | ||
|
|
520d9c66cf | ||
|
|
ff10498a8b | ||
|
|
00b747ad5d |
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.22.2
|
FROM alpine:3.23.2
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
|
|||||||
@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|||||||
profName = activeProf.Name
|
profName = activeProf.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
statusOutputString = overview.FullDetailSummary()
|
||||||
)
|
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||||
case jsonFlag:
|
case jsonFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
statusOutputString, err = outputInformationHolder.JSON()
|
||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
statusOutputString, err = outputInformationHolder.YAML()
|
||||||
default:
|
default:
|
||||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -38,6 +39,7 @@ type Client struct {
|
|||||||
setupKey string
|
setupKey string
|
||||||
jwtToken string
|
jwtToken string
|
||||||
connect *internal.ConnectClient
|
connect *internal.ConnectClient
|
||||||
|
recorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options configures a new Client.
|
// Options configures a new Client.
|
||||||
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
|
|||||||
func (c *Client) Start(startCtx context.Context) error {
|
func (c *Client) Start(startCtx context.Context) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
if c.cancel != nil {
|
if c.connect != nil {
|
||||||
return ErrClientAlreadyStarted
|
return ErrClientAlreadyStarted
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||||
|
defer func() {
|
||||||
|
if c.connect == nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||||
@@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
|
c.recorder = recorder
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||||
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.connect = client
|
c.connect = client
|
||||||
|
c.cancel = cancel
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
|
|||||||
return ErrClientNotStarted
|
return ErrClientNotStarted
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.cancel != nil {
|
||||||
|
c.cancel()
|
||||||
|
c.cancel = nil
|
||||||
|
}
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
connect := c.connect
|
||||||
go func() {
|
go func() {
|
||||||
done <- c.connect.Stop()
|
done <- connect.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
c.cancel = nil
|
c.connect = nil
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
c.cancel = nil
|
c.connect = nil
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("stop: %w", err)
|
return fmt.Errorf("stop: %w", err)
|
||||||
}
|
}
|
||||||
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Status returns the current status of the client.
|
||||||
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
recorder := c.recorder
|
||||||
|
connect := c.connect
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if recorder == nil {
|
||||||
|
return peer.FullStatus{}, errors.New("client not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
if connect != nil {
|
||||||
|
engine := connect.Engine()
|
||||||
|
if engine != nil {
|
||||||
|
_ = engine.RunHealthProbes(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return recorder.GetFullStatus(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
syncResp, err := engine.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get sync response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return syncResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the logging level for the client and its components.
|
||||||
|
func (c *Client) SetLogLevel(levelStr string) error {
|
||||||
|
level, err := logrus.ParseLevel(levelStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse log level: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.SetLevel(level)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
connect := c.connect
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if connect != nil {
|
||||||
|
connect.SetLogLevel(level)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||||
|
|||||||
@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
|||||||
return syncResponse, nil
|
return syncResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||||
|
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager != nil {
|
||||||
|
fwManager.SetLogLevel(level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current client status
|
// Status returns the current client status
|
||||||
func (c *ConnectClient) Status() StatusType {
|
func (c *ConnectClient) Status() StatusType {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
PriorityMgmtCache = 150
|
PriorityMgmtCache = 150
|
||||||
PriorityLocal = 100
|
PriorityDNSRoute = 100
|
||||||
PriorityDNSRoute = 75
|
PriorityLocal = 75
|
||||||
PriorityUpstream = 50
|
PriorityUpstream = 50
|
||||||
PriorityDefault = 1
|
PriorityDefault = 1
|
||||||
PriorityFallback = -100
|
PriorityFallback = -100
|
||||||
|
|||||||
@@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface,
|
||||||
s.wgInterface.Address().IP,
|
|
||||||
s.wgInterface.Address().Network,
|
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
nbdns.RootZone,
|
nbdns.RootZone,
|
||||||
@@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface,
|
||||||
s.wgInterface.Address().IP,
|
|
||||||
s.wgInterface.Address().Network,
|
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
domainGroup.domain,
|
domainGroup.domain,
|
||||||
@@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
|
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface,
|
||||||
s.wgInterface.Address().IP,
|
|
||||||
s.wgInterface.Address().Network,
|
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
nbdns.RootZone,
|
nbdns.RootZone,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
|||||||
return configurer.WGStats{}, nil
|
return configurer.WGStats{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetNet() *netstack.Net {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -2047,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
|||||||
|
|
||||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||||
// Test that priority constants are ordered correctly
|
// Test that priority constants are ordered correctly
|
||||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
|
||||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
@@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||||
|
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||||
|
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||||
|
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If response is truncated, retry with TCP
|
||||||
|
if reply != nil && reply.MsgHdr.Truncated {
|
||||||
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||||
|
}
|
||||||
|
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
|
||||||
|
conn, err := nsNet.DialContext(ctx, network, upstream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("with %s: %w", network, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close DNS connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
return nil, fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsConn := &dns.Conn{Conn: conn}
|
||||||
|
|
||||||
|
if err := dnsConn.WriteMsg(r); err != nil {
|
||||||
|
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := dnsConn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read %s message: %w", network, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||||
func FormatPeerStatus(peerState *peer.State) string {
|
func FormatPeerStatus(peerState *peer.State) string {
|
||||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ type upstreamResolver struct {
|
|||||||
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
_ WGIface,
|
||||||
_ netip.Addr,
|
|
||||||
_ netip.Prefix,
|
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
|
|||||||
@@ -5,22 +5,23 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
|
nsNet *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
wgIface WGIface,
|
||||||
_ netip.Addr,
|
|
||||||
_ netip.Prefix,
|
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@@ -28,12 +29,23 @@ func newUpstreamResolver(
|
|||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
|
nsNet: wgIface.GetNet(),
|
||||||
}
|
}
|
||||||
upstreamResolverBase.upstreamClient = nonIOS
|
upstreamResolverBase.upstreamClient = nonIOS
|
||||||
return nonIOS, nil
|
return nonIOS, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
|
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
|
||||||
|
// Similar to iOS logic, we should determine if the DNS server is reachable directly
|
||||||
|
// or needs to go through the tunnel, and only use netstack when necessary.
|
||||||
|
// For now, only use netstack on JS platform where direct access is not possible.
|
||||||
|
if u.nsNet != nil && runtime.GOOS == "js" {
|
||||||
|
start := time.Now()
|
||||||
|
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
|
||||||
|
return reply, time.Since(start), err
|
||||||
|
}
|
||||||
|
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Timeout: ClientTimeout,
|
Timeout: ClientTimeout,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
|
|||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
interfaceName string,
|
wgIface WGIface,
|
||||||
ip netip.Addr,
|
|
||||||
net netip.Prefix,
|
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@@ -37,9 +35,9 @@ func newUpstreamResolver(
|
|||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
lIP: ip,
|
lIP: wgIface.Address().IP,
|
||||||
lNet: net,
|
lNet: wgIface.Address().Network,
|
||||||
interfaceName: interfaceName,
|
interfaceName: wgIface.Name(),
|
||||||
}
|
}
|
||||||
ios.upstreamClient = ios
|
ios.upstreamClient = ios
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,17 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
|
||||||
// Convert test servers to netip.AddrPort
|
// Convert test servers to netip.AddrPort
|
||||||
var servers []netip.AddrPort
|
var servers []netip.AddrPort
|
||||||
for _, server := range testCase.InputServers {
|
for _, server := range testCase.InputServers {
|
||||||
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockNetstackProvider struct{}
|
||||||
|
|
||||||
|
func (m *mockNetstackProvider) Name() string { return "mock" }
|
||||||
|
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
|
||||||
|
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
|
||||||
|
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
|
||||||
|
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
|
||||||
|
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
|
||||||
|
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
|
||||||
|
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
type mockUpstreamResolver struct {
|
type mockUpstreamResolver struct {
|
||||||
r *dns.Msg
|
r *dns.Msg
|
||||||
rtt time.Duration
|
rtt time.Duration
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -17,4 +19,5 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetNet() *netstack.Net
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -12,5 +14,6 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetNet() *netstack.Net
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1748,6 +1748,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Skip STUN/TURN probing for JS/WASM as it's not available
|
||||||
|
relayHealthy := true
|
||||||
|
if runtime.GOOS != "js" {
|
||||||
var results []relay.ProbeResult
|
var results []relay.ProbeResult
|
||||||
if waitForResult {
|
if waitForResult {
|
||||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||||
@@ -1756,7 +1760,6 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
|||||||
}
|
}
|
||||||
e.statusRecorder.UpdateRelayStates(results)
|
e.statusRecorder.UpdateRelayStates(results)
|
||||||
|
|
||||||
relayHealthy := true
|
|
||||||
for _, res := range results {
|
for _, res := range results {
|
||||||
if res.Err != nil {
|
if res.Err != nil {
|
||||||
relayHealthy = false
|
relayHealthy = false
|
||||||
@@ -1764,6 +1767,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||||
|
}
|
||||||
|
|
||||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||||
|
|||||||
@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
audiences := protoJWT.GetAudiences()
|
||||||
|
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||||
|
audiences = []string{protoJWT.GetAudience()}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
|
||||||
|
|
||||||
jwtConfig := &sshserver.JWTConfig{
|
jwtConfig := &sshserver.JWTConfig{
|
||||||
Issuer: protoJWT.GetIssuer(),
|
Issuer: protoJWT.GetIssuer(),
|
||||||
Audience: protoJWT.GetAudience(),
|
Audiences: audiences,
|
||||||
KeysLocation: protoJWT.GetKeysLocation(),
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -158,6 +159,7 @@ type FullStatus struct {
|
|||||||
NSGroupStates []NSGroupState
|
NSGroupStates []NSGroupState
|
||||||
NumOfForwardingRules int
|
NumOfForwardingRules int
|
||||||
LazyConnectionEnabled bool
|
LazyConnectionEnabled bool
|
||||||
|
Events []*proto.SystemEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
type StatusChangeSubscription struct {
|
type StatusChangeSubscription struct {
|
||||||
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||||
|
fullStatus.Events = d.GetEventHistory()
|
||||||
return fullStatus
|
return fullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
|
|||||||
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||||
return s.events
|
return s.events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToProto converts FullStatus to proto.FullStatus.
|
||||||
|
func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||||
|
pbFullStatus := proto.FullStatus{
|
||||||
|
ManagementState: &proto.ManagementState{},
|
||||||
|
SignalState: &proto.SignalState{},
|
||||||
|
LocalPeerState: &proto.LocalPeerState{},
|
||||||
|
Peers: []*proto.PeerState{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
|
||||||
|
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
|
||||||
|
if err := fs.ManagementState.Error; err != nil {
|
||||||
|
pbFullStatus.ManagementState.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.SignalState.URL = fs.SignalState.URL
|
||||||
|
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
|
||||||
|
if err := fs.SignalState.Error; err != nil {
|
||||||
|
pbFullStatus.SignalState.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
|
||||||
|
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
|
||||||
|
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
|
||||||
|
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
|
||||||
|
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
|
||||||
|
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
|
||||||
|
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
|
||||||
|
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
|
||||||
|
|
||||||
|
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
|
||||||
|
|
||||||
|
for _, peerState := range fs.Peers {
|
||||||
|
networks := maps.Keys(peerState.GetRoutes())
|
||||||
|
|
||||||
|
pbPeerState := &proto.PeerState{
|
||||||
|
IP: peerState.IP,
|
||||||
|
PubKey: peerState.PubKey,
|
||||||
|
ConnStatus: peerState.ConnStatus.String(),
|
||||||
|
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||||
|
Relayed: peerState.Relayed,
|
||||||
|
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||||
|
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||||
|
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||||
|
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||||
|
RelayAddress: peerState.RelayServerAddress,
|
||||||
|
Fqdn: peerState.FQDN,
|
||||||
|
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||||
|
BytesRx: peerState.BytesRx,
|
||||||
|
BytesTx: peerState.BytesTx,
|
||||||
|
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||||
|
Networks: networks,
|
||||||
|
Latency: durationpb.New(peerState.Latency),
|
||||||
|
SshHostKey: peerState.SSHHostKey,
|
||||||
|
}
|
||||||
|
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, relayState := range fs.Relays {
|
||||||
|
pbRelayState := &proto.RelayState{
|
||||||
|
URI: relayState.URI,
|
||||||
|
Available: relayState.Err == nil,
|
||||||
|
}
|
||||||
|
if err := relayState.Err; err != nil {
|
||||||
|
pbRelayState.Error = err.Error()
|
||||||
|
}
|
||||||
|
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dnsState := range fs.NSGroupStates {
|
||||||
|
var err string
|
||||||
|
if dnsState.Error != nil {
|
||||||
|
err = dnsState.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
var servers []string
|
||||||
|
for _, server := range dnsState.Servers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
pbDnsState := &proto.NSGroupState{
|
||||||
|
Servers: servers,
|
||||||
|
Domains: dnsState.Domains,
|
||||||
|
Enabled: dnsState.Enabled,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus.Events = fs.Events
|
||||||
|
|
||||||
|
return &pbFullStatus
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,13 +17,13 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
|
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -38,11 +38,6 @@ type internalDNATer interface {
|
|||||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type wgInterface interface {
|
|
||||||
Name() string
|
|
||||||
Address() wgaddr.Address
|
|
||||||
}
|
|
||||||
|
|
||||||
type DnsInterceptor struct {
|
type DnsInterceptor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
route *route.Route
|
route *route.Route
|
||||||
@@ -52,7 +47,7 @@ type DnsInterceptor struct {
|
|||||||
dnsServer nbdns.Server
|
dnsServer nbdns.Server
|
||||||
currentPeerKey string
|
currentPeerKey string
|
||||||
interceptedDomains domainMap
|
interceptedDomains domainMap
|
||||||
wgInterface wgInterface
|
wgInterface iface.WGIface
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
fakeIPManager *fakeip.Manager
|
fakeIPManager *fakeip.Manager
|
||||||
@@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
|
||||||
if err != nil {
|
|
||||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
@@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
startTime := time.Now()
|
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
|
||||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
if reply == nil {
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
elapsed := time.Since(startTime)
|
|
||||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
|
||||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
|
||||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
|
||||||
} else {
|
|
||||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
|
||||||
}
|
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
|
||||||
logger.Errorf("failed writing DNS response: %v", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
|
||||||
|
// Returns the DNS reply on success, or nil on error (error responses are written internally).
|
||||||
|
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
nsNet := d.wgInterface.GetNet()
|
||||||
|
var reply *dns.Msg
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if nsNet != nil {
|
||||||
|
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||||
|
} else {
|
||||||
|
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||||
|
if clientErr != nil {
|
||||||
|
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return reply
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
elapsed := time.Since(startTime)
|
||||||
|
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||||
|
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||||
|
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||||
|
} else {
|
||||||
|
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
|
}
|
||||||
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||||
if d.statusRecorder == nil {
|
if d.statusRecorder == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetNet() *netstack.Net
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
|||||||
|
|
||||||
log.SetLevel(level)
|
log.SetLevel(level)
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient != nil {
|
||||||
return nil, fmt.Errorf("connect client not initialized")
|
s.connectClient.SetLogLevel(level)
|
||||||
}
|
}
|
||||||
engine := s.connectClient.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, fmt.Errorf("engine not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
fwManager := engine.GetFirewallManager()
|
|
||||||
if fwManager == nil {
|
|
||||||
return nil, fmt.Errorf("firewall manager not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
fwManager.SetLogLevel(level)
|
|
||||||
|
|
||||||
log.Infof("Log level set to %s", level.String())
|
log.Infof("Log level set to %s", level.String())
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
|
|
||||||
events := s.statusRecorder.GetEventHistory()
|
|
||||||
return &proto.GetEventsResponse{Events: events}, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,15 +13,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
|
|||||||
if msg.GetFullPeerStatus {
|
if msg.GetFullPeerStatus {
|
||||||
s.runProbes(msg.ShouldRunProbes)
|
s.runProbes(msg.ShouldRunProbes)
|
||||||
fullStatus := s.statusRecorder.GetFullStatus()
|
fullStatus := s.statusRecorder.GetFullStatus()
|
||||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
pbFullStatus := fullStatus.ToProto()
|
||||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||||
|
|
||||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||||
|
|
||||||
statusResponse.FullStatus = pbFullStatus
|
statusResponse.FullStatus = pbFullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
|||||||
return defaultDuration
|
return defaultDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|
||||||
pbFullStatus := proto.FullStatus{
|
|
||||||
ManagementState: &proto.ManagementState{},
|
|
||||||
SignalState: &proto.SignalState{},
|
|
||||||
LocalPeerState: &proto.LocalPeerState{},
|
|
||||||
Peers: []*proto.PeerState{},
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
|
||||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
|
||||||
if err := fullStatus.ManagementState.Error; err != nil {
|
|
||||||
pbFullStatus.ManagementState.Error = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
|
||||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
|
||||||
if err := fullStatus.SignalState.Error; err != nil {
|
|
||||||
pbFullStatus.SignalState.Error = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
|
||||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
|
||||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
|
||||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
|
||||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
|
||||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
|
||||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
|
||||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
|
||||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
|
||||||
|
|
||||||
for _, peerState := range fullStatus.Peers {
|
|
||||||
pbPeerState := &proto.PeerState{
|
|
||||||
IP: peerState.IP,
|
|
||||||
PubKey: peerState.PubKey,
|
|
||||||
ConnStatus: peerState.ConnStatus.String(),
|
|
||||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
|
||||||
Relayed: peerState.Relayed,
|
|
||||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
|
||||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
|
||||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
|
||||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
|
||||||
RelayAddress: peerState.RelayServerAddress,
|
|
||||||
Fqdn: peerState.FQDN,
|
|
||||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
|
||||||
BytesRx: peerState.BytesRx,
|
|
||||||
BytesTx: peerState.BytesTx,
|
|
||||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
|
||||||
Networks: maps.Keys(peerState.GetRoutes()),
|
|
||||||
Latency: durationpb.New(peerState.Latency),
|
|
||||||
SshHostKey: peerState.SSHHostKey,
|
|
||||||
}
|
|
||||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, relayState := range fullStatus.Relays {
|
|
||||||
pbRelayState := &proto.RelayState{
|
|
||||||
URI: relayState.URI,
|
|
||||||
Available: relayState.Err == nil,
|
|
||||||
}
|
|
||||||
if err := relayState.Err; err != nil {
|
|
||||||
pbRelayState.Error = err.Error()
|
|
||||||
}
|
|
||||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dnsState := range fullStatus.NSGroupStates {
|
|
||||||
var err string
|
|
||||||
if dnsState.Error != nil {
|
|
||||||
err = dnsState.Error.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
var servers []string
|
|
||||||
for _, server := range dnsState.Servers {
|
|
||||||
servers = append(servers, server.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
pbDnsState := &proto.NSGroupState{
|
|
||||||
Servers: servers,
|
|
||||||
Domains: dnsState.Domains,
|
|
||||||
Enabled: dnsState.Enabled,
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &pbFullStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendTerminalNotification sends a terminal notification message
|
// sendTerminalNotification sends a terminal notification message
|
||||||
// to inform the user that the NetBird connection session has expired.
|
// to inform the user that the NetBird connection session has expired.
|
||||||
func sendTerminalNotification() error {
|
func sendTerminalNotification() error {
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
HostKeyPEM: hostKey,
|
HostKeyPEM: hostKey,
|
||||||
JWT: &server.JWTConfig{
|
JWT: &server.JWTConfig{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Audience: audience,
|
Audiences: []string{audience},
|
||||||
KeysLocation: jwksURL,
|
KeysLocation: jwksURL,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
|
|||||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||||
jwtConfig := &JWTConfig{
|
jwtConfig := &JWTConfig{
|
||||||
Issuer: "test-issuer",
|
Issuer: "test-issuer",
|
||||||
Audience: "test-audience",
|
Audiences: []string{"test-audience"},
|
||||||
KeysLocation: "test-keys",
|
KeysLocation: "test-keys",
|
||||||
}
|
}
|
||||||
serverConfig := &Config{
|
serverConfig := &Config{
|
||||||
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
|
|||||||
|
|
||||||
jwtConfig := &JWTConfig{
|
jwtConfig := &JWTConfig{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Audience: audience,
|
Audiences: []string{audience},
|
||||||
KeysLocation: jwksURL,
|
KeysLocation: jwksURL,
|
||||||
}
|
}
|
||||||
serverConfig := &Config{
|
serverConfig := &Config{
|
||||||
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
jwtConfig := &JWTConfig{
|
jwtConfig := &JWTConfig{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Audience: audience,
|
Audiences: []string{audience},
|
||||||
KeysLocation: jwksURL,
|
KeysLocation: jwksURL,
|
||||||
MaxTokenAge: 3600,
|
MaxTokenAge: 3600,
|
||||||
}
|
}
|
||||||
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
|
|
||||||
jwtConfig := &JWTConfig{
|
jwtConfig := &JWTConfig{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Audience: audience,
|
Audiences: []string{audience},
|
||||||
KeysLocation: jwksURL,
|
KeysLocation: jwksURL,
|
||||||
}
|
}
|
||||||
serverConfig := &Config{
|
serverConfig := &Config{
|
||||||
@@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
|
||||||
|
func TestJWTMultipleAudiences(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping JWT multiple audiences tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
dashboardAudience = "dashboard-audience"
|
||||||
|
cliAudience = "cli-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
audience string
|
||||||
|
wantAuthOK bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "accepts_dashboard_audience",
|
||||||
|
audience: dashboardAudience,
|
||||||
|
wantAuthOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "accepts_cli_audience",
|
||||||
|
audience: cliAudience,
|
||||||
|
wantAuthOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rejects_unknown_audience",
|
||||||
|
audience: "unknown-audience",
|
||||||
|
wantAuthOK: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
jwtConfig := &JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audiences: []string{dashboardAudience, cliAudience},
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
}
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
currentUser: {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := generateValidJWT(t, privateKey, issuer, tc.audience)
|
||||||
|
config := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{
|
||||||
|
cryptossh.Password(token),
|
||||||
|
},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||||
|
if tc.wantAuthOK {
|
||||||
|
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Logf("close connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
err = session.Shell()
|
||||||
|
require.NoError(t, err, "Shell should work with valid audience")
|
||||||
|
} else {
|
||||||
|
assert.Error(t, err, "JWT authentication should fail for unknown audience")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -176,9 +176,9 @@ type Server struct {
|
|||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
Issuer string
|
Issuer string
|
||||||
Audience string
|
|
||||||
KeysLocation string
|
KeysLocation string
|
||||||
MaxTokenAge int64
|
MaxTokenAge int64
|
||||||
|
Audiences []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains all SSH server configuration options
|
// Config contains all SSH server configuration options
|
||||||
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
|
|||||||
return fmt.Errorf("JWT config not set")
|
return fmt.Errorf("JWT config not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
if len(config.Audiences) == 0 {
|
||||||
|
return fmt.Errorf("JWT config has no audiences configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
|
||||||
validator := jwt.NewValidator(
|
validator := jwt.NewValidator(
|
||||||
config.Issuer,
|
config.Issuer,
|
||||||
[]string{config.Audience},
|
config.Audiences,
|
||||||
config.KeysLocation,
|
config.KeysLocation,
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Use custom userIDClaim from authorizer if available
|
// Use custom userIDClaim from authorizer if available
|
||||||
extractorOptions := []jwt.ClaimsExtractorOption{
|
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||||
jwt.WithAudience(config.Audience),
|
jwt.WithAudience(config.Audiences[0]),
|
||||||
}
|
}
|
||||||
if authorizer.GetUserIDClaim() != "" {
|
if authorizer.GetUserIDClaim() != "" {
|
||||||
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||||
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if jwtConfig != nil {
|
if jwtConfig != nil {
|
||||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
|
||||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("validate token: %w", err)
|
return nil, fmt.Errorf("validate token: %w", err)
|
||||||
|
|||||||
@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseToJSON(overview OutputOverview) (string, error) {
|
// JSON returns the status overview as a JSON string.
|
||||||
jsonBytes, err := json.Marshal(overview)
|
func (o *OutputOverview) JSON() (string, error) {
|
||||||
|
jsonBytes, err := json.Marshal(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("json marshal failed")
|
return "", fmt.Errorf("json marshal failed")
|
||||||
}
|
}
|
||||||
return string(jsonBytes), err
|
return string(jsonBytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseToYAML(overview OutputOverview) (string, error) {
|
// YAML returns the status overview as a YAML string.
|
||||||
yamlBytes, err := yaml.Marshal(overview)
|
func (o *OutputOverview) YAML() (string, error) {
|
||||||
|
yamlBytes, err := yaml.Marshal(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("yaml marshal failed")
|
return "", fmt.Errorf("yaml marshal failed")
|
||||||
}
|
}
|
||||||
return string(yamlBytes), nil
|
return string(yamlBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
// GeneralSummary returns a general summary of the status overview.
|
||||||
|
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||||
var managementConnString string
|
var managementConnString string
|
||||||
if overview.ManagementState.Connected {
|
if o.ManagementState.Connected {
|
||||||
managementConnString = "Connected"
|
managementConnString = "Connected"
|
||||||
if showURL {
|
if showURL {
|
||||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
managementConnString = "Disconnected"
|
managementConnString = "Disconnected"
|
||||||
if overview.ManagementState.Error != "" {
|
if o.ManagementState.Error != "" {
|
||||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var signalConnString string
|
var signalConnString string
|
||||||
if overview.SignalState.Connected {
|
if o.SignalState.Connected {
|
||||||
signalConnString = "Connected"
|
signalConnString = "Connected"
|
||||||
if showURL {
|
if showURL {
|
||||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
signalConnString = "Disconnected"
|
signalConnString = "Disconnected"
|
||||||
if overview.SignalState.Error != "" {
|
if o.SignalState.Error != "" {
|
||||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interfaceTypeString := "Userspace"
|
interfaceTypeString := "Userspace"
|
||||||
interfaceIP := overview.IP
|
interfaceIP := o.IP
|
||||||
if overview.KernelInterface {
|
if o.KernelInterface {
|
||||||
interfaceTypeString = "Kernel"
|
interfaceTypeString = "Kernel"
|
||||||
} else if overview.IP == "" {
|
} else if o.IP == "" {
|
||||||
interfaceTypeString = "N/A"
|
interfaceTypeString = "N/A"
|
||||||
interfaceIP = "N/A"
|
interfaceIP = "N/A"
|
||||||
}
|
}
|
||||||
|
|
||||||
var relaysString string
|
var relaysString string
|
||||||
if showRelays {
|
if showRelays {
|
||||||
for _, relay := range overview.Relays.Details {
|
for _, relay := range o.Relays.Details {
|
||||||
available := "Available"
|
available := "Available"
|
||||||
reason := ""
|
reason := ""
|
||||||
|
|
||||||
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
|
||||||
}
|
}
|
||||||
|
|
||||||
networks := "-"
|
networks := "-"
|
||||||
if len(overview.Networks) > 0 {
|
if len(o.Networks) > 0 {
|
||||||
sort.Strings(overview.Networks)
|
sort.Strings(o.Networks)
|
||||||
networks = strings.Join(overview.Networks, ", ")
|
networks = strings.Join(o.Networks, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
var dnsServersString string
|
var dnsServersString string
|
||||||
if showNameServers {
|
if showNameServers {
|
||||||
for _, nsServerGroup := range overview.NSServerGroups {
|
for _, nsServerGroup := range o.NSServerGroups {
|
||||||
enabled := "Available"
|
enabled := "Available"
|
||||||
if !nsServerGroup.Enabled {
|
if !nsServerGroup.Enabled {
|
||||||
enabled = "Unavailable"
|
enabled = "Unavailable"
|
||||||
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
|
||||||
}
|
}
|
||||||
|
|
||||||
rosenpassEnabledStatus := "false"
|
rosenpassEnabledStatus := "false"
|
||||||
if overview.RosenpassEnabled {
|
if o.RosenpassEnabled {
|
||||||
rosenpassEnabledStatus = "true"
|
rosenpassEnabledStatus = "true"
|
||||||
if overview.RosenpassPermissive {
|
if o.RosenpassPermissive {
|
||||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lazyConnectionEnabledStatus := "false"
|
lazyConnectionEnabledStatus := "false"
|
||||||
if overview.LazyConnectionEnabled {
|
if o.LazyConnectionEnabled {
|
||||||
lazyConnectionEnabledStatus = "true"
|
lazyConnectionEnabledStatus = "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServerStatus := "Disabled"
|
sshServerStatus := "Disabled"
|
||||||
if overview.SSHServerState.Enabled {
|
if o.SSHServerState.Enabled {
|
||||||
sessionCount := len(overview.SSHServerState.Sessions)
|
sessionCount := len(o.SSHServerState.Sessions)
|
||||||
if sessionCount > 0 {
|
if sessionCount > 0 {
|
||||||
sessionWord := "session"
|
sessionWord := "session"
|
||||||
if sessionCount > 1 {
|
if sessionCount > 1 {
|
||||||
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if showSSHSessions && sessionCount > 0 {
|
if showSSHSessions && sessionCount > 0 {
|
||||||
for _, session := range overview.SSHServerState.Sessions {
|
for _, session := range o.SSHServerState.Sessions {
|
||||||
var sessionDisplay string
|
var sessionDisplay string
|
||||||
if session.JWTUsername != "" {
|
if session.JWTUsername != "" {
|
||||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
||||||
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||||
|
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
goarch := runtime.GOARCH
|
goarch := runtime.GOARCH
|
||||||
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
"Forwarding rules: %d\n"+
|
"Forwarding rules: %d\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||||
overview.DaemonVersion,
|
o.DaemonVersion,
|
||||||
version.NetbirdVersion(),
|
version.NetbirdVersion(),
|
||||||
overview.ProfileName,
|
o.ProfileName,
|
||||||
managementConnString,
|
managementConnString,
|
||||||
signalConnString,
|
signalConnString,
|
||||||
relaysString,
|
relaysString,
|
||||||
dnsServersString,
|
dnsServersString,
|
||||||
domain.Domain(overview.FQDN).SafeString(),
|
domain.Domain(o.FQDN).SafeString(),
|
||||||
interfaceIP,
|
interfaceIP,
|
||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
lazyConnectionEnabledStatus,
|
lazyConnectionEnabledStatus,
|
||||||
sshServerStatus,
|
sshServerStatus,
|
||||||
networks,
|
networks,
|
||||||
overview.NumberOfForwardingRules,
|
o.NumberOfForwardingRules,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
)
|
)
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
// FullDetailSummary returns a full detailed summary with peer details and events.
|
||||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
func (o *OutputOverview) FullDetailSummary() string {
|
||||||
parsedEventsString := parseEvents(overview.Events)
|
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
|
||||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
parsedEventsString := parseEvents(o.Events)
|
||||||
|
summary := o.GeneralSummary(true, true, true, true)
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"Peers detail:"+
|
"Peers detail:"+
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToJSON(t *testing.T) {
|
func TestParsingToJSON(t *testing.T) {
|
||||||
jsonString, _ := ParseToJSON(overview)
|
jsonString, _ := overview.JSON()
|
||||||
|
|
||||||
//@formatter:off
|
//@formatter:off
|
||||||
expectedJSONString := `
|
expectedJSONString := `
|
||||||
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToYAML(t *testing.T) {
|
func TestParsingToYAML(t *testing.T) {
|
||||||
yaml, _ := ParseToYAML(overview)
|
yaml, _ := overview.YAML()
|
||||||
|
|
||||||
expectedYAML :=
|
expectedYAML :=
|
||||||
`peers:
|
`peers:
|
||||||
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
|
|||||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||||
|
|
||||||
detail := ParseToFullDetailSummary(overview)
|
detail := overview.FullDetailSummary()
|
||||||
|
|
||||||
expectedDetail := fmt.Sprintf(
|
expectedDetail := fmt.Sprintf(
|
||||||
`Peers detail:
|
`Peers detail:
|
||||||
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToShortVersion(t *testing.T) {
|
func TestParsingToShortVersion(t *testing.T) {
|
||||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
shortVersion := overview.GeneralSummary(false, false, false, false)
|
||||||
|
|
||||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||||
Daemon version: 0.14.1
|
Daemon version: 0.14.1
|
||||||
|
|||||||
22
client/system/disk_encryption.go
Normal file
22
client/system/disk_encryption.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package system
|
||||||
|
|
||||||
|
// DiskEncryptionVolume represents encryption status of a single volume.
|
||||||
|
type DiskEncryptionVolume struct {
|
||||||
|
Path string
|
||||||
|
Encrypted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiskEncryptionInfo holds disk encryption detection results.
|
||||||
|
type DiskEncryptionInfo struct {
|
||||||
|
Volumes []DiskEncryptionVolume
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEncrypted returns true if the volume at the given path is encrypted.
|
||||||
|
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||||
|
for _, v := range d.Volumes {
|
||||||
|
if v.Path == path {
|
||||||
|
return v.Encrypted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
35
client/system/disk_encryption_darwin.go
Normal file
35
client/system/disk_encryption_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// detectDiskEncryption detects FileVault encryption status on macOS.
|
||||||
|
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||||
|
info := DiskEncryptionInfo{}
|
||||||
|
|
||||||
|
cmdCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(cmdCtx, "fdesetup", "status")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("execute fdesetup: %v", err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := strings.Contains(string(output), "FileVault is On")
|
||||||
|
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||||
|
Path: "/",
|
||||||
|
Encrypted: encrypted,
|
||||||
|
})
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
98
client/system/disk_encryption_linux.go
Normal file
98
client/system/disk_encryption_linux.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// detectDiskEncryption detects LUKS encryption status on Linux by reading sysfs.
|
||||||
|
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||||
|
info := DiskEncryptionInfo{}
|
||||||
|
|
||||||
|
encryptedDevices := findEncryptedDevices()
|
||||||
|
mountPoints := parseMounts(encryptedDevices)
|
||||||
|
|
||||||
|
info.Volumes = mountPoints
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// findEncryptedDevices scans /sys/block for dm-crypt (LUKS) encrypted devices.
|
||||||
|
func findEncryptedDevices() map[string]bool {
|
||||||
|
encryptedDevices := make(map[string]bool)
|
||||||
|
|
||||||
|
sysBlock := "/sys/block"
|
||||||
|
entries, err := os.ReadDir(sysBlock)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("read /sys/block: %v", err)
|
||||||
|
return encryptedDevices
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
dmUuidPath := filepath.Join(sysBlock, entry.Name(), "dm", "uuid")
|
||||||
|
data, err := os.ReadFile(dmUuidPath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
uuid := strings.TrimSpace(string(data))
|
||||||
|
if strings.HasPrefix(uuid, "CRYPT-") {
|
||||||
|
dmNamePath := filepath.Join(sysBlock, entry.Name(), "dm", "name")
|
||||||
|
if nameData, err := os.ReadFile(dmNamePath); err == nil {
|
||||||
|
dmName := strings.TrimSpace(string(nameData))
|
||||||
|
encryptedDevices["/dev/mapper/"+dmName] = true
|
||||||
|
}
|
||||||
|
encryptedDevices["/dev/"+entry.Name()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedDevices
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMounts reads /proc/mounts and maps devices to mount points with encryption status.
|
||||||
|
func parseMounts(encryptedDevices map[string]bool) []DiskEncryptionVolume {
|
||||||
|
var volumes []DiskEncryptionVolume
|
||||||
|
|
||||||
|
mountsFile, err := os.Open("/proc/mounts")
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("open /proc/mounts: %v", err)
|
||||||
|
return volumes
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := mountsFile.Close(); err != nil {
|
||||||
|
log.Debugf("close /proc/mounts: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(mountsFile)
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if len(fields) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
device, mountPoint := fields[0], fields[1]
|
||||||
|
|
||||||
|
encrypted := encryptedDevices[device]
|
||||||
|
|
||||||
|
if !encrypted && strings.HasPrefix(device, "/dev/mapper/") {
|
||||||
|
for encDev := range encryptedDevices {
|
||||||
|
if device == encDev {
|
||||||
|
encrypted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
volumes = append(volumes, DiskEncryptionVolume{
|
||||||
|
Path: mountPoint,
|
||||||
|
Encrypted: encrypted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return volumes
|
||||||
|
}
|
||||||
10
client/system/disk_encryption_stub.go
Normal file
10
client/system/disk_encryption_stub.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build android || ios || freebsd || js
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// detectDiskEncryption is a stub for unsupported platforms.
|
||||||
|
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||||
|
return DiskEncryptionInfo{}
|
||||||
|
}
|
||||||
41
client/system/disk_encryption_windows.go
Normal file
41
client/system/disk_encryption_windows.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Win32EncryptableVolume represents the WMI class for BitLocker status.
|
||||||
|
type Win32EncryptableVolume struct {
|
||||||
|
DriveLetter string
|
||||||
|
ProtectionStatus uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectDiskEncryption detects BitLocker encryption status on Windows via WMI.
|
||||||
|
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||||
|
info := DiskEncryptionInfo{}
|
||||||
|
|
||||||
|
var volumes []Win32EncryptableVolume
|
||||||
|
query := "SELECT DriveLetter, ProtectionStatus FROM Win32_EncryptableVolume"
|
||||||
|
|
||||||
|
err := wmi.QueryNamespace(query, &volumes, `root\CIMV2\Security\MicrosoftVolumeEncryption`)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("query BitLocker status: %v", err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, vol := range volumes {
|
||||||
|
driveLetter := strings.TrimSuffix(vol.DriveLetter, "\\")
|
||||||
|
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||||
|
Path: driveLetter,
|
||||||
|
Encrypted: vol.ProtectionStatus == 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
@@ -59,6 +59,7 @@ type Info struct {
|
|||||||
SystemManufacturer string
|
SystemManufacturer string
|
||||||
Environment Environment
|
Environment Environment
|
||||||
Files []File // for posture checks
|
Files []File // for posture checks
|
||||||
|
DiskEncryption DiskEncryptionInfo
|
||||||
|
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
SystemSerialNumber: serial(),
|
SystemSerialNumber: serial(),
|
||||||
SystemProductName: productModel(),
|
SystemProductName: productModel(),
|
||||||
SystemManufacturer: productManufacturer(),
|
SystemManufacturer: productManufacturer(),
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
SystemProductName: si.SystemProductName,
|
SystemProductName: si.SystemProductName,
|
||||||
SystemManufacturer: si.SystemManufacturer,
|
SystemManufacturer: si.SystemManufacturer,
|
||||||
Environment: si.Environment,
|
Environment: si.Environment,
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
UIVersion: extractUserAgent(ctx),
|
UIVersion: extractUserAgent(ctx),
|
||||||
KernelVersion: osInfo[1],
|
KernelVersion: osInfo[1],
|
||||||
Environment: env,
|
Environment: env,
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
sysName := extractOsName(ctx, "sysName")
|
sysName := extractOsName(ctx, "sysName")
|
||||||
swVersion := extractOsVersion(ctx, "swVersion")
|
swVersion := extractOsVersion(ctx, "swVersion")
|
||||||
|
|
||||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
|
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion, DiskEncryption: detectDiskEncryption(ctx)}
|
||||||
gio.Hostname = extractDeviceName(ctx, "hostname")
|
gio.Hostname = extractDeviceName(ctx, "hostname")
|
||||||
gio.NetbirdVersion = version.NetbirdVersion()
|
gio.NetbirdVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func UpdateStaticInfoAsync() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfo retrieves system information for WASM environment
|
// GetInfo retrieves system information for WASM environment
|
||||||
func GetInfo(_ context.Context) *Info {
|
func GetInfo(ctx context.Context) *Info {
|
||||||
info := &Info{
|
info := &Info{
|
||||||
GoOS: runtime.GOOS,
|
GoOS: runtime.GOOS,
|
||||||
Kernel: runtime.GOARCH,
|
Kernel: runtime.GOARCH,
|
||||||
@@ -25,6 +25,7 @@ func GetInfo(_ context.Context) *Info {
|
|||||||
Hostname: "wasm-client",
|
Hostname: "wasm-client",
|
||||||
CPUs: runtime.NumCPU(),
|
CPUs: runtime.NumCPU(),
|
||||||
NetbirdVersion: version.NetbirdVersion(),
|
NetbirdVersion: version.NetbirdVersion(),
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
collectBrowserInfo(info)
|
collectBrowserInfo(info)
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
SystemProductName: si.SystemProductName,
|
SystemProductName: si.SystemProductName,
|
||||||
SystemManufacturer: si.SystemManufacturer,
|
SystemManufacturer: si.SystemManufacturer,
|
||||||
Environment: si.Environment,
|
Environment: si.Environment,
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
SystemProductName: si.SystemProductName,
|
SystemProductName: si.SystemProductName,
|
||||||
SystemManufacturer: si.SystemManufacturer,
|
SystemManufacturer: si.SystemManufacturer,
|
||||||
Environment: si.Environment,
|
Environment: si.Environment,
|
||||||
|
DiskEncryption: detectDiskEncryption(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs, err := networkAddresses()
|
addrs, err := networkAddresses()
|
||||||
|
|||||||
@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
var postUpStatusOutput string
|
var postUpStatusOutput string
|
||||||
if postUpStatus != nil {
|
if postUpStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
postUpStatusOutput = overview.FullDetailSummary()
|
||||||
}
|
}
|
||||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||||
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
var preDownStatusOutput string
|
var preDownStatusOutput string
|
||||||
if preDownStatus != nil {
|
if preDownStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
preDownStatusOutput = overview.FullDetailSummary()
|
||||||
}
|
}
|
||||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||||
time.Now().Format(time.RFC3339), params.duration)
|
time.Now().Format(time.RFC3339), params.duration)
|
||||||
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
var statusOutput string
|
var statusOutput string
|
||||||
if statusResp != nil {
|
if statusResp != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
statusOutput = overview.FullDetailSummary()
|
||||||
}
|
}
|
||||||
|
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
|
|||||||
@@ -9,20 +9,29 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
netbird "github.com/netbirdio/netbird/client/embed"
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
clientStartTimeout = 30 * time.Second
|
clientStartTimeout = 30 * time.Second
|
||||||
clientStopTimeout = 10 * time.Second
|
clientStopTimeout = 10 * time.Second
|
||||||
|
pingTimeout = 10 * time.Second
|
||||||
defaultLogLevel = "warn"
|
defaultLogLevel = "warn"
|
||||||
defaultSSHDetectionTimeout = 20 * time.Second
|
defaultSSHDetectionTimeout = 20 * time.Second
|
||||||
|
|
||||||
|
icmpEchoRequest = 8
|
||||||
|
icmpCodeEcho = 0
|
||||||
|
pingBufferSize = 1500
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateSSHArgs validates SSH connection arguments
|
||||||
|
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return "", 0, "", js.ValueOf("error: requires host and port")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[0].Type() != js.TypeString {
|
||||||
|
return "", 0, "", js.ValueOf("host parameter must be a string")
|
||||||
|
}
|
||||||
|
if args[1].Type() != js.TypeNumber {
|
||||||
|
return "", 0, "", js.ValueOf("port parameter must be a number")
|
||||||
|
}
|
||||||
|
|
||||||
|
host = args[0].String()
|
||||||
|
port = args[1].Int()
|
||||||
|
username = "root"
|
||||||
|
|
||||||
|
if len(args) > 2 {
|
||||||
|
if args[2].Type() == js.TypeString && args[2].String() != "" {
|
||||||
|
username = args[2].String()
|
||||||
|
} else if args[2].Type() != js.TypeString {
|
||||||
|
return "", 0, "", js.ValueOf("username parameter must be a string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, port, username, js.Undefined()
|
||||||
|
}
|
||||||
|
|
||||||
// createSSHMethod creates the SSH connection method
|
// createSSHMethod creates the SSH connection method
|
||||||
func createSSHMethod(client *netbird.Client) js.Func {
|
func createSSHMethod(client *netbird.Client) js.Func {
|
||||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
if len(args) < 2 {
|
host, port, username, validationErr := validateSSHArgs(args)
|
||||||
return js.ValueOf("error: requires host and port")
|
if !validationErr.IsUndefined() {
|
||||||
|
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
|
||||||
|
return validationErr
|
||||||
}
|
}
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
host := args[0].String()
|
reject.Invoke(validationErr)
|
||||||
port := args[1].Int()
|
})
|
||||||
username := "root"
|
|
||||||
if len(args) > 2 && args[2].String() != "" {
|
|
||||||
username = args[2].String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jwtToken string
|
var jwtToken string
|
||||||
@@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func performPing(client *netbird.Client, hostname string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
conn, err := client.Dial(ctx, "ping", hostname)
|
||||||
|
if err != nil {
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close ping connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
icmpData := make([]byte, 8)
|
||||||
|
icmpData[0] = icmpEchoRequest
|
||||||
|
icmpData[1] = icmpCodeEcho
|
||||||
|
|
||||||
|
if _, err := conn.Write(icmpData); err != nil {
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, pingBufferSize)
|
||||||
|
if _, err := conn.Read(buf); err != nil {
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
latency := time.Since(start)
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||||
|
start := time.Now()
|
||||||
|
conn, err := client.Dial(ctx, "tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close TCP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// createPingMethod creates the ping method
|
||||||
|
func createPingMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return js.ValueOf("error: hostname required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[0].Type() != js.TypeString {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := args[0].String()
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
performPing(client, hostname)
|
||||||
|
resolve.Invoke(js.Undefined())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createPingTCPMethod creates the pingtcp method
|
||||||
|
func createPingTCPMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return js.ValueOf("error: hostname and port required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[0].Type() != js.TypeString {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[1].Type() != js.TypeNumber {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("port parameter must be a number"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := args[0].String()
|
||||||
|
port := args[1].Int()
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
performPingTCP(client, hostname, port)
|
||||||
|
resolve.Invoke(js.Undefined())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// createProxyRequestMethod creates the proxyRequest method
|
// createProxyRequestMethod creates the proxyRequest method
|
||||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
@@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request := args[0]
|
request := args[0]
|
||||||
|
if request.Type() != js.TypeObject {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("request parameter must be an object"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
response, err := http.ProxyRequest(client, request)
|
response, err := http.ProxyRequest(client, request)
|
||||||
@@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
|||||||
return js.ValueOf("error: hostname and port required")
|
return js.ValueOf("error: hostname and port required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args[0].Type() != js.TypeString {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if args[1].Type() != js.TypeString {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("port parameter must be a string"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
proxy := rdp.NewRDCleanPathProxy(client)
|
proxy := rdp.NewRDCleanPathProxy(client)
|
||||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getStatusOverview is a helper to get the status overview
|
||||||
|
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||||
|
fullStatus, err := client.Status()
|
||||||
|
if err != nil {
|
||||||
|
return nbstatus.OutputOverview{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pbFullStatus := fullStatus.ToProto()
|
||||||
|
statusResp := &proto.StatusResponse{
|
||||||
|
DaemonVersion: version.NetbirdVersion(),
|
||||||
|
FullStatus: pbFullStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createStatusMethod creates the status method that returns JSON
|
||||||
|
func createStatusMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
overview, err := getStatusOverview(client)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonStr, err := overview.JSON()
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
|
||||||
|
resolve.Invoke(jsonObj)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createStatusSummaryMethod creates the statusSummary method
|
||||||
|
func createStatusSummaryMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
overview, err := getStatusOverview(client)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := overview.GeneralSummary(false, false, false, false)
|
||||||
|
js.Global().Get("console").Call("log", summary)
|
||||||
|
resolve.Invoke(js.Undefined())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createStatusDetailMethod creates the statusDetail method
|
||||||
|
func createStatusDetailMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
overview, err := getStatusOverview(client)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
detail := overview.FullDetailSummary()
|
||||||
|
js.Global().Get("console").Call("log", detail)
|
||||||
|
resolve.Invoke(js.Undefined())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
|
||||||
|
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
syncResp, err := client.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
options := protojson.MarshalOptions{
|
||||||
|
EmitUnpopulated: true,
|
||||||
|
UseProtoNames: true,
|
||||||
|
AllowPartial: true,
|
||||||
|
}
|
||||||
|
jsonBytes, err := options.Marshal(syncResp)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
|
||||||
|
resolve.Invoke(jsonObj)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
|
||||||
|
func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return js.ValueOf("error: log level required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[0].Type() != js.TypeString {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
reject.Invoke(js.ValueOf("log level parameter must be a string"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logLevel := args[0].String()
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
if err := client.SetLogLevel(logLevel); err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Log level set to: %s", logLevel)
|
||||||
|
resolve.Invoke(js.ValueOf(true))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// createPromise is a helper to create JavaScript promises
|
// createPromise is a helper to create JavaScript promises
|
||||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||||
@@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value {
|
|||||||
|
|
||||||
obj["start"] = createStartMethod(client)
|
obj["start"] = createStartMethod(client)
|
||||||
obj["stop"] = createStopMethod(client)
|
obj["stop"] = createStopMethod(client)
|
||||||
|
obj["ping"] = createPingMethod(client)
|
||||||
|
obj["pingtcp"] = createPingTCPMethod(client)
|
||||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||||
obj["createSSHConnection"] = createSSHMethod(client)
|
obj["createSSHConnection"] = createSSHMethod(client)
|
||||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||||
|
obj["status"] = createStatusMethod(client)
|
||||||
|
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||||
|
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||||
|
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||||
|
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||||
|
|
||||||
return js.ValueOf(obj)
|
return js.ValueOf(obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||||
resolve := promiseArgs[0]
|
resolve := promiseArgs[0]
|
||||||
reject := promiseArgs[1]
|
reject := promiseArgs[1]
|
||||||
|
|
||||||
|
|||||||
8
go.mod
8
go.mod
@@ -78,8 +78,8 @@ require (
|
|||||||
github.com/pion/logging v0.2.4
|
github.com/pion/logging v0.2.4
|
||||||
github.com/pion/randutil v0.1.0
|
github.com/pion/randutil v0.1.0
|
||||||
github.com/pion/stun/v2 v2.0.0
|
github.com/pion/stun/v2 v2.0.0
|
||||||
github.com/pion/stun/v3 v3.0.0
|
github.com/pion/stun/v3 v3.1.0
|
||||||
github.com/pion/transport/v3 v3.0.7
|
github.com/pion/transport/v3 v3.1.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
github.com/pkg/sftp v1.13.9
|
github.com/pkg/sftp v1.13.9
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
@@ -241,7 +241,7 @@ require (
|
|||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||||
github.com/pion/dtls/v3 v3.0.7 // indirect
|
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||||
github.com/pion/turn/v4 v4.1.1 // indirect
|
github.com/pion/turn/v4 v4.1.1 // indirect
|
||||||
@@ -263,7 +263,7 @@ require (
|
|||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
github.com/wlynxg/anet v0.0.3 // indirect
|
github.com/wlynxg/anet v0.0.5 // indirect
|
||||||
github.com/yuin/goldmark v1.7.8 // indirect
|
github.com/yuin/goldmark v1.7.8 // indirect
|
||||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
|
|||||||
16
go.sum
16
go.sum
@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
|
|||||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||||
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
|
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||||
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
|
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||||
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
|||||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||||
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
||||||
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
||||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
|
||||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
|
||||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||||
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
||||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||||
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
||||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||||
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||||
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
|||||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
|
|||||||
@@ -798,15 +798,15 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
|||||||
"redirectURI": redirectURI,
|
"redirectURI": redirectURI,
|
||||||
"scopes": []string{"openid", "profile", "email"},
|
"scopes": []string{"openid", "profile", "email"},
|
||||||
"insecureEnableGroups": true,
|
"insecureEnableGroups": true,
|
||||||
|
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
|
||||||
|
"insecureSkipEmailVerified": true,
|
||||||
}
|
}
|
||||||
switch cfg.Type {
|
switch cfg.Type {
|
||||||
case "zitadel":
|
case "zitadel":
|
||||||
oidcConfig["getUserInfo"] = true
|
oidcConfig["getUserInfo"] = true
|
||||||
case "entra":
|
case "entra":
|
||||||
oidcConfig["insecureSkipEmailVerified"] = true
|
|
||||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||||
case "okta":
|
case "okta":
|
||||||
oidcConfig["insecureSkipEmailVerified"] = true
|
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "pocketid":
|
case "pocketid":
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
|
|
||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||||
|
return fmt.Errorf("failed to get account zones: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||||
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
|
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
if c.experimentalNetworkMap(accountId) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||||
|
return nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
var networkMap *types.NetworkMap
|
var networkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
|
|||||||
accountId string,
|
accountId string,
|
||||||
peerId string,
|
peerId string,
|
||||||
validatedPeers map[string]struct{},
|
validatedPeers map[string]struct{},
|
||||||
customZone nbdns.CustomZone,
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
) *types.NetworkMap {
|
) *types.NetworkMap {
|
||||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||||
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||||
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
|
||||||
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
var networkMap *types.NetworkMap
|
var networkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/server/peer"
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -14,6 +15,7 @@ type Repository interface {
|
|||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||||
|
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type repository struct {
|
type repository struct {
|
||||||
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
|
|||||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||||
|
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
}
|
||||||
|
|||||||
13
management/internals/modules/zones/interface.go
Normal file
13
management/internals/modules/zones/interface.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package zones
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
|
||||||
|
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
|
||||||
|
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||||
|
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||||
|
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
|
||||||
|
}
|
||||||
161
management/internals/modules/zones/manager/api.go
Normal file
161
management/internals/modules/zones/manager/api.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager zones.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiZones := make([]*api.Zone, 0, len(allZones))
|
||||||
|
for _, zone := range allZones {
|
||||||
|
apiZones = append(apiZones, zone.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiDnsZonesJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zone := new(zones.Zone)
|
||||||
|
zone.FromAPIRequest(&req)
|
||||||
|
|
||||||
|
if err = zone.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||||
|
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zone := new(zones.Zone)
|
||||||
|
zone.FromAPIRequest(&req)
|
||||||
|
zone.ID = zoneID
|
||||||
|
|
||||||
|
if err = zone.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
229
management/internals/modules/zones/manager/manager.go
Normal file
229
management/internals/modules/zones/manager/manager.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
dnsDomain string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
accountManager: accountManager,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
dnsDomain: dnsDomain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existingZone != nil {
|
||||||
|
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range zone.DistributionGroups {
|
||||||
|
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreateZone(ctx, zone); err != nil {
|
||||||
|
return fmt.Errorf("failed to create zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
|
||||||
|
|
||||||
|
return zone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if zone.Domain != updatedZone.Domain {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
zone.Name = updatedZone.Name
|
||||||
|
zone.Enabled = updatedZone.Enabled
|
||||||
|
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
|
||||||
|
zone.DistributionGroups = updatedZone.DistributionGroups
|
||||||
|
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
for _, groupID := range zone.DistributionGroups {
|
||||||
|
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.UpdateZone(ctx, zone); err != nil {
|
||||||
|
return fmt.Errorf("failed to update zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||||
|
|
||||||
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return zone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventsToStore []func()
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete zone dns records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.DeleteZone(ctx, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range eventsToStore {
|
||||||
|
event()
|
||||||
|
}
|
||||||
|
|
||||||
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
|
||||||
|
if m.dnsDomain != "" && m.dnsDomain == domain {
|
||||||
|
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.DNSDomain != "" && settings.DNSDomain == domain {
|
||||||
|
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
553
management/internals/modules/zones/manager/manager_test.go
Normal file
553
management/internals/modules/zones/manager/manager_test.go
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testAccountID = "test-account-id"
|
||||||
|
testUserID = "test-user-id"
|
||||||
|
testZoneID = "test-zone-id"
|
||||||
|
testGroupID = "test-group-id"
|
||||||
|
testDNSDomain = "netbird.selfhosted"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = testStore.SaveAccount(ctx, &types.Account{
|
||||||
|
Id: testAccountID,
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
testGroupID: {
|
||||||
|
ID: testGroupID,
|
||||||
|
Name: "Test Group",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockAccountManager := &mock_server.MockAccountManager{}
|
||||||
|
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: testStore,
|
||||||
|
accountManager: mockAccountManager,
|
||||||
|
permissionsManager: mockPermissionsManager,
|
||||||
|
dnsDomain: testDNSDomain,
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, zone1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
|
||||||
|
err = testStore.CreateZone(ctx, zone2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, result, 2)
|
||||||
|
assert.Equal(t, zone1.ID, result[0].ID)
|
||||||
|
assert.Equal(t, zone2.ID, result[1].ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission validation error", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||||
|
|
||||||
|
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_GetZone(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, zone.ID, result.ID)
|
||||||
|
assert.Equal(t, zone.Name, result.Name)
|
||||||
|
assert.Equal(t, zone.Domain, result.Domain)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "New Zone",
|
||||||
|
Domain: "new.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: true,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSZoneCreated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.NotEmpty(t, result.ID)
|
||||||
|
assert.Equal(t, testAccountID, result.AccountID)
|
||||||
|
assert.Equal(t, inputZone.Name, result.Name)
|
||||||
|
assert.Equal(t, inputZone.Domain, result.Domain)
|
||||||
|
assert.Equal(t, inputZone.Enabled, result.Enabled)
|
||||||
|
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||||
|
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "New Zone",
|
||||||
|
Domain: "new.example.com",
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid group", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "New Zone",
|
||||||
|
Domain: "new.example.com",
|
||||||
|
DistributionGroups: []string{"invalid-group"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate domain", func(t *testing.T) {
|
||||||
|
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, existingZone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "New Zone",
|
||||||
|
Domain: "duplicate.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.AlreadyExists, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||||
|
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
account, err := testStore.GetAccount(ctx, testAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
account.Settings.DNSDomain = "peers.example.com"
|
||||||
|
err = testStore.SaveAccount(ctx, account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "Test Zone",
|
||||||
|
Domain: "peers.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputZone := &zones.Zone{
|
||||||
|
Name: "Test Zone",
|
||||||
|
Domain: testDNSDomain,
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, existingZone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedZone := &zones.Zone{
|
||||||
|
ID: existingZone.ID,
|
||||||
|
Name: "Updated Name",
|
||||||
|
Domain: "example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: true,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
storeEventCalled := false
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
storeEventCalled = true
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, existingZone.ID, targetID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSZoneUpdated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, updatedZone.Name, result.Name)
|
||||||
|
assert.Equal(t, updatedZone.Enabled, result.Enabled)
|
||||||
|
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||||
|
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("domain change not allowed", func(t *testing.T) {
|
||||||
|
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, existingZone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedZone := &zones.Zone{
|
||||||
|
ID: existingZone.ID,
|
||||||
|
Name: "Test Zone",
|
||||||
|
Domain: "different.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: true,
|
||||||
|
DistributionGroups: []string{testGroupID},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "zone domain cannot be updated")
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
updatedZone := &zones.Zone{
|
||||||
|
ID: testZoneID,
|
||||||
|
Name: "Updated Name",
|
||||||
|
Domain: "example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("zone not found", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
updatedZone := &zones.Zone{
|
||||||
|
ID: "non-existent-zone",
|
||||||
|
Name: "Updated Name",
|
||||||
|
Domain: "example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success with records", func(t *testing.T) {
|
||||||
|
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = testStore.CreateDNSRecord(ctx, record1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||||
|
err = testStore.CreateDNSRecord(ctx, record2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
storeEventCallCount := 0
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
storeEventCallCount++
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, storeEventCallCount)
|
||||||
|
|
||||||
|
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, zoneRecords)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success without records", func(t *testing.T) {
|
||||||
|
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||||
|
err := testStore.CreateZone(ctx, zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
storeEventCalled := false
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
storeEventCalled = true
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, zone.ID, targetID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSZoneDeleted, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||||
|
|
||||||
|
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||||
|
require.Error(t, err)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("zone not found", func(t *testing.T) {
|
||||||
|
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
13
management/internals/modules/zones/records/interface.go
Normal file
13
management/internals/modules/zones/records/interface.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package records
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
|
||||||
|
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
|
||||||
|
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||||
|
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||||
|
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
|
||||||
|
}
|
||||||
191
management/internals/modules/zones/records/manager/api.go
Normal file
191
management/internals/modules/zones/records/manager/api.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager records.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
|
||||||
|
for _, record := range allRecords {
|
||||||
|
apiRecords = append(apiRecords, record.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record := new(records.Record)
|
||||||
|
record.FromAPIRequest(&req)
|
||||||
|
|
||||||
|
if err = record.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
recordID := mux.Vars(r)["recordId"]
|
||||||
|
if recordID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
recordID := mux.Vars(r)["recordId"]
|
||||||
|
if recordID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||||
|
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
record := new(records.Record)
|
||||||
|
record.FromAPIRequest(&req)
|
||||||
|
record.ID = recordID
|
||||||
|
|
||||||
|
if err = record.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneID := mux.Vars(r)["zoneId"]
|
||||||
|
if zoneID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
recordID := mux.Vars(r)["recordId"]
|
||||||
|
if recordID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
236
management/internals/modules/zones/records/manager/manager.go
Normal file
236
management/internals/modules/zones/records/manager/manager.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
accountManager: accountManager,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var zone *zones.Zone
|
||||||
|
|
||||||
|
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateRecordConflicts(ctx, transaction, zone, record)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
|
||||||
|
return fmt.Errorf("failed to create dns record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||||
|
|
||||||
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var zone *zones.Zone
|
||||||
|
var record *records.Record
|
||||||
|
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
|
||||||
|
|
||||||
|
record.Name = updatedRecord.Name
|
||||||
|
record.Type = updatedRecord.Type
|
||||||
|
record.Content = updatedRecord.Content
|
||||||
|
record.TTL = updatedRecord.TTL
|
||||||
|
|
||||||
|
if hasChanges {
|
||||||
|
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
|
||||||
|
return fmt.Errorf("failed to update dns record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||||
|
|
||||||
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var record *records.Record
|
||||||
|
var zone *zones.Zone
|
||||||
|
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get zone: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete dns record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||||
|
|
||||||
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRecordConflicts checks for duplicate records and CNAME conflicts
|
||||||
|
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
|
||||||
|
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
|
||||||
|
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
|
||||||
|
}
|
||||||
|
|
||||||
|
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check existing records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existing := range existingRecords {
|
||||||
|
if existing.ID == record.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing.Type == record.Type && existing.Content == record.Content {
|
||||||
|
return status.Errorf(status.AlreadyExists, "identical record already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
|
||||||
|
return status.Errorf(status.InvalidArgument,
|
||||||
|
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,573 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testAccountID = "test-account-id"
|
||||||
|
testUserID = "test-user-id"
|
||||||
|
testRecordID = "test-record-id"
|
||||||
|
testGroupID = "test-group-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = testStore.SaveAccount(ctx, &types.Account{
|
||||||
|
Id: testAccountID,
|
||||||
|
Groups: map[string]*types.Group{
|
||||||
|
testGroupID: {
|
||||||
|
ID: testGroupID,
|
||||||
|
Name: "Test Group",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||||
|
err = testStore.CreateZone(ctx, zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockAccountManager := &mock_server.MockAccountManager{}
|
||||||
|
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: testStore,
|
||||||
|
accountManager: mockAccountManager,
|
||||||
|
permissionsManager: mockPermissionsManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, record1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||||
|
err = testStore.CreateDNSRecord(ctx, record2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, result, 2)
|
||||||
|
assert.Equal(t, record1.ID, result[0].ID)
|
||||||
|
assert.Equal(t, record2.ID, result[1].ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission validation error", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||||
|
|
||||||
|
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, record.ID, result.ID)
|
||||||
|
assert.Equal(t, record.Name, result.Name)
|
||||||
|
assert.Equal(t, record.Type, result.Type)
|
||||||
|
assert.Equal(t, record.Content, result.Content)
|
||||||
|
assert.Equal(t, record.TTL, result.TTL)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success - A record", func(t *testing.T) {
|
||||||
|
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.NotEmpty(t, result.ID)
|
||||||
|
assert.Equal(t, testAccountID, result.AccountID)
|
||||||
|
assert.Equal(t, zone.ID, result.ZoneID)
|
||||||
|
assert.Equal(t, inputRecord.Name, result.Name)
|
||||||
|
assert.Equal(t, inputRecord.Type, result.Type)
|
||||||
|
assert.Equal(t, inputRecord.Content, result.Content)
|
||||||
|
assert.Equal(t, inputRecord.TTL, result.TTL)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success - AAAA record", func(t *testing.T) {
|
||||||
|
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "ipv6.example.com",
|
||||||
|
Type: records.RecordTypeAAAA,
|
||||||
|
Content: "2001:db8::1",
|
||||||
|
TTL: 600,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, inputRecord.Type, result.Type)
|
||||||
|
assert.Equal(t, inputRecord.Content, result.Content)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success - CNAME record", func(t *testing.T) {
|
||||||
|
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "www.example.com",
|
||||||
|
Type: records.RecordTypeCNAME,
|
||||||
|
Content: "example.com",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, inputRecord.Type, result.Type)
|
||||||
|
assert.Equal(t, inputRecord.Content, result.Content)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("record name not in zone", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "api.different.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "does not belong to zone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate record", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "identical record already exists")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inputRecord := &records.Record{
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeCNAME,
|
||||||
|
Content: "example.com",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "already exists")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedRecord := &records.Record{
|
||||||
|
ID: existingRecord.ID,
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.100", // Changed IP
|
||||||
|
TTL: 600, // Changed TTL
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
storeEventCalled := false
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
storeEventCalled = true
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, existingRecord.ID, targetID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSRecordUpdated, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, updatedRecord.Content, result.Content)
|
||||||
|
assert.Equal(t, updatedRecord.TTL, result.TTL)
|
||||||
|
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedRecord := &records.Record{
|
||||||
|
ID: existingRecord.ID,
|
||||||
|
Name: existingRecord.Name,
|
||||||
|
Type: existingRecord.Type,
|
||||||
|
Content: existingRecord.Content,
|
||||||
|
TTL: 600, // Only TTL changed
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
// Event should be stored
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, 600, result.TTL)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
updatedRecord := &records.Record{
|
||||||
|
ID: testRecordID,
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.100",
|
||||||
|
TTL: 600,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("record not found", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
updatedRecord := &records.Record{
|
||||||
|
ID: "non-existent-record",
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.100",
|
||||||
|
TTL: 600,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update creates duplicate", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, record1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||||
|
err = testStore.CreateDNSRecord(ctx, record2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedRecord := &records.Record{
|
||||||
|
ID: record2.ID,
|
||||||
|
Name: "api.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "identical record already exists")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err := testStore.CreateDNSRecord(ctx, record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
storeEventCalled := false
|
||||||
|
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||||
|
storeEventCalled = true
|
||||||
|
assert.Equal(t, testUserID, initiatorID)
|
||||||
|
assert.Equal(t, record.ID, targetID)
|
||||||
|
assert.Equal(t, testAccountID, accountID)
|
||||||
|
assert.Equal(t, activity.DNSRecordDeleted, activityID)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||||
|
|
||||||
|
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("permission denied", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(false, nil)
|
||||||
|
|
||||||
|
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||||
|
require.Error(t, err)
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("record not found", func(t *testing.T) {
|
||||||
|
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||||
|
defer cleanup()
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockPermissionsManager.EXPECT().
|
||||||
|
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||||
|
Return(true, nil)
|
||||||
|
|
||||||
|
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
129
management/internals/modules/zones/records/record.go
Normal file
129
management/internals/modules/zones/records/record.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package records
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RecordType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RecordTypeA RecordType = "A"
|
||||||
|
RecordTypeAAAA RecordType = "AAAA"
|
||||||
|
RecordTypeCNAME RecordType = "CNAME"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Record struct {
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
ZoneID string `gorm:"index"`
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
Type RecordType
|
||||||
|
Content string
|
||||||
|
TTL int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
|
||||||
|
return &Record{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
|
ZoneID: zoneID,
|
||||||
|
Name: name,
|
||||||
|
Type: recordType,
|
||||||
|
Content: content,
|
||||||
|
TTL: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Record) ToAPIResponse() *api.DNSRecord {
|
||||||
|
recordType := api.DNSRecordType(r.Type)
|
||||||
|
return &api.DNSRecord{
|
||||||
|
Id: r.ID,
|
||||||
|
Name: r.Name,
|
||||||
|
Type: recordType,
|
||||||
|
Content: r.Content,
|
||||||
|
Ttl: r.TTL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
|
||||||
|
r.Name = req.Name
|
||||||
|
r.Type = RecordType(req.Type)
|
||||||
|
r.Content = req.Content
|
||||||
|
r.TTL = req.Ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Record) Validate() error {
|
||||||
|
if r.Name == "" {
|
||||||
|
return errors.New("record name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !util.IsValidDomain(r.Name) {
|
||||||
|
return errors.New("invalid record name format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Type == "" {
|
||||||
|
return errors.New("record type is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Type {
|
||||||
|
case RecordTypeA:
|
||||||
|
if err := validateIPv4(r.Content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case RecordTypeAAAA:
|
||||||
|
if err := validateIPv6(r.Content); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case RecordTypeCNAME:
|
||||||
|
if !util.IsValidDomain(r.Content) {
|
||||||
|
return errors.New("invalid CNAME record format")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TTL < 0 {
|
||||||
|
return errors.New("TTL cannot be negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"name": r.Name,
|
||||||
|
"type": string(r.Type),
|
||||||
|
"content": r.Content,
|
||||||
|
"ttl": r.TTL,
|
||||||
|
"zone_id": zoneID,
|
||||||
|
"zone_name": zoneName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateIPv4(content string) error {
|
||||||
|
if content == "" {
|
||||||
|
return errors.New("A record is required") //nolint:staticcheck
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(content)
|
||||||
|
if ip == nil || ip.To4() == nil {
|
||||||
|
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateIPv6(content string) error {
|
||||||
|
if content == "" {
|
||||||
|
return errors.New("AAAA record is required")
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(content)
|
||||||
|
if ip == nil || ip.To4() != nil {
|
||||||
|
return errors.New("AAAA record must be a valid IPv6 address")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
89
management/internals/modules/zones/zone.go
Normal file
89
management/internals/modules/zones/zone.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package zones
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Zone struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
Name string
|
||||||
|
Domain string
|
||||||
|
Enabled bool
|
||||||
|
EnableSearchDomain bool
|
||||||
|
DistributionGroups []string `gorm:"serializer:json"`
|
||||||
|
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
|
||||||
|
return &Zone{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: name,
|
||||||
|
Domain: domain,
|
||||||
|
Enabled: enabled,
|
||||||
|
EnableSearchDomain: enableSearchDomain,
|
||||||
|
DistributionGroups: distributionGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *Zone) ToAPIResponse() *api.Zone {
|
||||||
|
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
|
||||||
|
for _, record := range z.Records {
|
||||||
|
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
|
||||||
|
apiRecords = append(apiRecords, *apiRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.Zone{
|
||||||
|
DistributionGroups: z.DistributionGroups,
|
||||||
|
Domain: z.Domain,
|
||||||
|
EnableSearchDomain: z.EnableSearchDomain,
|
||||||
|
Enabled: z.Enabled,
|
||||||
|
Id: z.ID,
|
||||||
|
Name: z.Name,
|
||||||
|
Records: apiRecords,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
|
||||||
|
z.Name = req.Name
|
||||||
|
z.Domain = req.Domain
|
||||||
|
z.EnableSearchDomain = req.EnableSearchDomain
|
||||||
|
z.DistributionGroups = req.DistributionGroups
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
if req.Enabled != nil {
|
||||||
|
enabled = *req.Enabled
|
||||||
|
}
|
||||||
|
z.Enabled = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *Zone) Validate() error {
|
||||||
|
if z.Name == "" {
|
||||||
|
return errors.New("zone name is required")
|
||||||
|
}
|
||||||
|
if len(z.Name) > 255 {
|
||||||
|
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !util.IsValidDomain(z.Domain) {
|
||||||
|
return errors.New("invalid zone domain format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(z.DistributionGroups) == 0 {
|
||||||
|
return errors.New("at least one distribution group is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *Zone) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": z.Name, "domain": z.Domain}
|
||||||
|
}
|
||||||
@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
|||||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||||
|
return Create(s, func() zones.Manager {
|
||||||
|
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) RecordsManager() records.Manager {
|
||||||
|
return Create(s, func() records.Manager {
|
||||||
|
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -376,6 +376,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
|||||||
protoZone := &proto.CustomZone{
|
protoZone := &proto.CustomZone{
|
||||||
Domain: zone.Domain,
|
Domain: zone.Domain,
|
||||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||||
|
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||||
NonAuthoritative: zone.NonAuthoritative,
|
NonAuthoritative: zone.NonAuthoritative,
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
@@ -433,9 +434,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
|
|||||||
if config.CLIAuthAudience != "" {
|
if config.CLIAuthAudience != "" {
|
||||||
audience = config.CLIAuthAudience
|
audience = config.CLIAuthAudience
|
||||||
}
|
}
|
||||||
|
|
||||||
|
audiences := []string{config.AuthAudience}
|
||||||
|
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
|
||||||
|
audiences = append(audiences, config.CLIAuthAudience)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.JWTConfig{
|
return &proto.JWTConfig{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Audience: audience,
|
Audience: audience,
|
||||||
|
Audiences: audiences,
|
||||||
KeysLocation: keysLocation,
|
KeysLocation: keysLocation,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||||
@@ -148,3 +151,51 @@ func generateTestData(size int) nbdns.Config {
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authAudience string
|
||||||
|
cliAuthAudience string
|
||||||
|
expectedAudiences []string
|
||||||
|
expectedAudience string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "only_auth_audience",
|
||||||
|
authAudience: "dashboard-aud",
|
||||||
|
cliAuthAudience: "",
|
||||||
|
expectedAudiences: []string{"dashboard-aud"},
|
||||||
|
expectedAudience: "dashboard-aud",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both_audiences_different",
|
||||||
|
authAudience: "dashboard-aud",
|
||||||
|
cliAuthAudience: "cli-aud",
|
||||||
|
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
|
||||||
|
expectedAudience: "cli-aud",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both_audiences_same",
|
||||||
|
authAudience: "same-aud",
|
||||||
|
cliAuthAudience: "same-aud",
|
||||||
|
expectedAudiences: []string{"same-aud"},
|
||||||
|
expectedAudience: "same-aud",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
config := &nbconfig.HttpServerConfig{
|
||||||
|
AuthIssuer: "https://issuer.example.com",
|
||||||
|
AuthAudience: tc.authAudience,
|
||||||
|
CLIAuthAudience: tc.cliAuthAudience,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := buildJWTConfig(config, nil)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
||||||
|
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -470,6 +470,16 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
diskEncryptionVolumes := make([]nbpeer.DiskEncryptionVolume, 0)
|
||||||
|
if meta.GetDiskEncryption() != nil {
|
||||||
|
for _, vol := range meta.GetDiskEncryption().GetVolumes() {
|
||||||
|
diskEncryptionVolumes = append(diskEncryptionVolumes, nbpeer.DiskEncryptionVolume{
|
||||||
|
Path: vol.GetPath(),
|
||||||
|
Encrypted: vol.GetEncrypted(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nbpeer.PeerSystemMeta{
|
return nbpeer.PeerSystemMeta{
|
||||||
Hostname: meta.GetHostname(),
|
Hostname: meta.GetHostname(),
|
||||||
GoOS: meta.GetGoOS(),
|
GoOS: meta.GetGoOS(),
|
||||||
@@ -501,6 +511,9 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
|||||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||||
},
|
},
|
||||||
Files: files,
|
Files: files,
|
||||||
|
DiskEncryption: nbpeer.DiskEncryptionInfo{
|
||||||
|
Volumes: diskEncryptionVolumes,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return newSettings, nil
|
return newSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||||
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
|
|||||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
|
||||||
|
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existingZone != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeIDs := make(map[route.ID]struct{}, 2)
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||||
|
|
||||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||||
|
|
||||||
assert.Len(t, emptyRoutes, 0)
|
assert.Len(t, emptyRoutes, 0)
|
||||||
}
|
}
|
||||||
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = manager.Store.CreateZone(ctx, &zones.Zone{
|
||||||
|
ID: "test-zone-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: "Test Zone",
|
||||||
|
Domain: "custom.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to create custom DNS zone")
|
||||||
|
|
||||||
|
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||||
|
DNSDomain: "custom.example.com",
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
PeerLoginExpirationEnabled: false,
|
||||||
|
Extra: &types.ExtraSettings{},
|
||||||
|
})
|
||||||
|
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
|
||||||
|
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -187,6 +187,14 @@ const (
|
|||||||
IdentityProviderUpdated Activity = 94
|
IdentityProviderUpdated Activity = 94
|
||||||
IdentityProviderDeleted Activity = 95
|
IdentityProviderDeleted Activity = 95
|
||||||
|
|
||||||
|
DNSZoneCreated Activity = 96
|
||||||
|
DNSZoneUpdated Activity = 97
|
||||||
|
DNSZoneDeleted Activity = 98
|
||||||
|
|
||||||
|
DNSRecordCreated Activity = 99
|
||||||
|
DNSRecordUpdated Activity = 100
|
||||||
|
DNSRecordDeleted Activity = 101
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
|
|||||||
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
||||||
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
||||||
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
||||||
|
|
||||||
|
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
|
||||||
|
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
|
||||||
|
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
|
||||||
|
|
||||||
|
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
|
||||||
|
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
|
||||||
|
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
25
management/server/cache/idp.go
vendored
25
management/server/cache/idp.go
vendored
@@ -26,6 +26,8 @@ type UserDataCache interface {
|
|||||||
Get(ctx context.Context, key string) (*idp.UserData, error)
|
Get(ctx context.Context, key string) (*idp.UserData, error)
|
||||||
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
||||||
Delete(ctx context.Context, key string) error
|
Delete(ctx context.Context, key string) error
|
||||||
|
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
|
||||||
|
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
||||||
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
|
|||||||
return u.cache.Delete(ctx, key)
|
return u.cache.Delete(ctx, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
|
||||||
|
var users []*idp.UserData
|
||||||
|
v, err := u.cache.Get(ctx, key, &users)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := v.(type) {
|
||||||
|
case []*idp.UserData:
|
||||||
|
return v, nil
|
||||||
|
case *[]*idp.UserData:
|
||||||
|
return *v, nil
|
||||||
|
case []byte:
|
||||||
|
return unmarshalUserData(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unexpected type: %T", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
|
||||||
|
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
|
||||||
|
}
|
||||||
|
|
||||||
// NewUserDataCache creates a new UserDataCacheImpl object.
|
// NewUserDataCache creates a new UserDataCacheImpl object.
|
||||||
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
||||||
simpleCache := cache.New[any](store)
|
simpleCache := cache.New[any](store)
|
||||||
|
|||||||
@@ -15,7 +15,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
|
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
|
||||||
@@ -56,7 +59,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
dns.AddEndpoints(accountManager, router)
|
dns.AddEndpoints(accountManager, router)
|
||||||
events.AddEndpoints(accountManager, router)
|
events.AddEndpoints(accountManager, router)
|
||||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||||
|
zonesManager.RegisterEndpoints(router, zManager)
|
||||||
|
recordsManager.RegisterEndpoints(router, rManager)
|
||||||
idp.AddEndpoints(accountManager, router)
|
idp.AddEndpoints(accountManager, router)
|
||||||
instance.AddEndpoints(instanceManager, router)
|
instance.AddEndpoints(instanceManager, router)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
m.patUsageTracker.IncrementUsage(token)
|
m.patUsageTracker.IncrementUsage(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.rateLimiter != nil {
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||||
if !m.rateLimiter.Allow(token) {
|
if !m.rateLimiter.Allow(token) {
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
}
|
}
|
||||||
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTerraformRequest(r *http.Request) bool {
|
||||||
|
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||||
|
return strings.Contains(ua, "terraform")
|
||||||
|
}
|
||||||
|
|
||||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||||
// the JWT token from the Authorization header.
|
// the JWT token from the Authorization header.
|
||||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||||
|
|||||||
@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
handler.ServeHTTP(rec, req)
|
handler.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||||
|
rateLimitConfig := &RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 1,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
authMiddleware := NewAuthMiddleware(
|
||||||
|
mockAuth,
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
|
return &types.User{}, nil
|
||||||
|
},
|
||||||
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Test various Terraform user agent formats
|
||||||
|
terraformUserAgents := []string{
|
||||||
|
"Terraform/1.5.0",
|
||||||
|
"terraform/1.0.0",
|
||||||
|
"Terraform-Provider/2.0.0",
|
||||||
|
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userAgent := range terraformUserAgents {
|
||||||
|
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||||
|
successCount := 0
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code == http.StatusOK {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||||
|
rateLimitConfig := &RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 1,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
authMiddleware := NewAuthMiddleware(
|
||||||
|
mockAuth,
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
|
return &types.User{}, nil
|
||||||
|
},
|
||||||
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
|
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
routersManagerMock := routers.NewManagerMock()
|
routersManagerMock := routers.NewManagerMock()
|
||||||
groupsManagerMock := groups.NewManagerMock()
|
groupsManagerMock := groups.NewManagerMock()
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,9 +136,10 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.AuthIssuer == "" {
|
if config.AuthIssuer == "" {
|
||||||
|
|||||||
@@ -48,13 +48,12 @@ type AuthentikCredentials struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||||
func NewAuthentikManager(config AuthentikClientConfig,
|
func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,9 +58,10 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@@ -49,9 +48,10 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.CustomerID == "" {
|
if config.CustomerID == "" {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||||
|
|
||||||
@@ -46,9 +45,10 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.APIToken == "" {
|
if config.APIToken == "" {
|
||||||
|
|||||||
@@ -63,9 +63,10 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ClientID == "" {
|
if config.ClientID == "" {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||||
@@ -45,7 +44,7 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
@@ -88,9 +87,10 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
if config.ManagementEndpoint == "" {
|
if config.ManagementEndpoint == "" {
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -69,3 +71,24 @@ func baseURL(rawURL string) string {
|
|||||||
|
|
||||||
return parsedURL.Scheme + "://" + parsedURL.Host
|
return parsedURL.Scheme + "://" + parsedURL.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Provides the env variable name for use with idpTimeout function
|
||||||
|
idpTimeoutEnv = "NB_IDP_TIMEOUT"
|
||||||
|
// Sets the defaultTimeout to 10s.
|
||||||
|
defaultTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// idpTimeout returns a timeout value for the IDP
|
||||||
|
func idpTimeout() time.Duration {
|
||||||
|
timeoutStr, ok := os.LookupEnv(idpTimeoutEnv)
|
||||||
|
if !ok || timeoutStr == "" {
|
||||||
|
return defaultTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout, err := time.ParseDuration(timeoutStr)
|
||||||
|
if err != nil {
|
||||||
|
return defaultTimeout
|
||||||
|
}
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
|||||||
@@ -164,9 +164,10 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri
|
|||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: idpTimeout(),
|
||||||
Transport: httpTransport,
|
Transport: httpTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
helper := JsonParser{}
|
helper := JsonParser{}
|
||||||
|
|
||||||
hasPAT := config.PAT != ""
|
hasPAT := config.PAT != ""
|
||||||
|
|||||||
@@ -95,6 +95,27 @@ type File struct {
|
|||||||
ProcessIsRunning bool
|
ProcessIsRunning bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DiskEncryptionVolume represents encryption status of a volume.
|
||||||
|
type DiskEncryptionVolume struct {
|
||||||
|
Path string
|
||||||
|
Encrypted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiskEncryptionInfo holds encryption info for all volumes.
|
||||||
|
type DiskEncryptionInfo struct {
|
||||||
|
Volumes []DiskEncryptionVolume `gorm:"serializer:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEncrypted returns true if the volume at path is encrypted.
|
||||||
|
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||||
|
for _, v := range d.Volumes {
|
||||||
|
if v.Path == path {
|
||||||
|
return v.Encrypted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Flags defines a set of options to control feature behavior
|
// Flags defines a set of options to control feature behavior
|
||||||
type Flags struct {
|
type Flags struct {
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
@@ -130,6 +151,7 @@ type PeerSystemMeta struct { //nolint:revive
|
|||||||
Environment Environment `gorm:"serializer:json"`
|
Environment Environment `gorm:"serializer:json"`
|
||||||
Flags Flags `gorm:"serializer:json"`
|
Flags Flags `gorm:"serializer:json"`
|
||||||
Files []File `gorm:"serializer:json"`
|
Files []File `gorm:"serializer:json"`
|
||||||
|
DiskEncryption DiskEncryptionInfo `gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||||
@@ -159,6 +181,19 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Slice(p.DiskEncryption.Volumes, func(i, j int) bool {
|
||||||
|
return p.DiskEncryption.Volumes[i].Path < p.DiskEncryption.Volumes[j].Path
|
||||||
|
})
|
||||||
|
sort.Slice(other.DiskEncryption.Volumes, func(i, j int) bool {
|
||||||
|
return other.DiskEncryption.Volumes[i].Path < other.DiskEncryption.Volumes[j].Path
|
||||||
|
})
|
||||||
|
equalDiskEncryption := slices.EqualFunc(p.DiskEncryption.Volumes, other.DiskEncryption.Volumes, func(vol DiskEncryptionVolume, oVol DiskEncryptionVolume) bool {
|
||||||
|
return vol.Path == oVol.Path && vol.Encrypted == oVol.Encrypted
|
||||||
|
})
|
||||||
|
if !equalDiskEncryption {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
return p.Hostname == other.Hostname &&
|
return p.Hostname == other.Hostname &&
|
||||||
p.GoOS == other.GoOS &&
|
p.GoOS == other.GoOS &&
|
||||||
p.Kernel == other.Kernel &&
|
p.Kernel == other.Kernel &&
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ const (
|
|||||||
GeoLocationCheckName = "GeoLocationCheck"
|
GeoLocationCheckName = "GeoLocationCheck"
|
||||||
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
|
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
|
||||||
ProcessCheckName = "ProcessCheck"
|
ProcessCheckName = "ProcessCheck"
|
||||||
|
DiskEncryptionCheckName = "DiskEncryptionCheck"
|
||||||
|
|
||||||
CheckActionAllow string = "allow"
|
CheckActionAllow string = "allow"
|
||||||
CheckActionDeny string = "deny"
|
CheckActionDeny string = "deny"
|
||||||
@@ -58,6 +59,7 @@ type ChecksDefinition struct {
|
|||||||
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
|
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
|
||||||
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
|
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
|
||||||
ProcessCheck *ProcessCheck `json:",omitempty"`
|
ProcessCheck *ProcessCheck `json:",omitempty"`
|
||||||
|
DiskEncryptionCheck *DiskEncryptionCheck `json:",omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy returns a copy of a checks definition.
|
// Copy returns a copy of a checks definition.
|
||||||
@@ -110,6 +112,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
|
|||||||
}
|
}
|
||||||
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
|
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
|
||||||
}
|
}
|
||||||
|
if cd.DiskEncryptionCheck != nil {
|
||||||
|
cdCopy.DiskEncryptionCheck = &DiskEncryptionCheck{
|
||||||
|
LinuxPath: cd.DiskEncryptionCheck.LinuxPath,
|
||||||
|
DarwinPath: cd.DiskEncryptionCheck.DarwinPath,
|
||||||
|
WindowsPath: cd.DiskEncryptionCheck.WindowsPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
return cdCopy
|
return cdCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,6 +162,9 @@ func (pc *Checks) GetChecks() []Check {
|
|||||||
if pc.Checks.ProcessCheck != nil {
|
if pc.Checks.ProcessCheck != nil {
|
||||||
checks = append(checks, pc.Checks.ProcessCheck)
|
checks = append(checks, pc.Checks.ProcessCheck)
|
||||||
}
|
}
|
||||||
|
if pc.Checks.DiskEncryptionCheck != nil {
|
||||||
|
checks = append(checks, pc.Checks.DiskEncryptionCheck)
|
||||||
|
}
|
||||||
return checks
|
return checks
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,6 +220,10 @@ func buildPostureCheck(postureChecksID string, name string, description string,
|
|||||||
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
|
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if diskEncryptionCheck := checks.DiskEncryptionCheck; diskEncryptionCheck != nil {
|
||||||
|
postureChecks.Checks.DiskEncryptionCheck = toDiskEncryptionCheck(diskEncryptionCheck)
|
||||||
|
}
|
||||||
|
|
||||||
return &postureChecks, nil
|
return &postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,6 +258,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
|
|||||||
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
|
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pc.Checks.DiskEncryptionCheck != nil {
|
||||||
|
checks.DiskEncryptionCheck = toDiskEncryptionCheckResponse(pc.Checks.DiskEncryptionCheck)
|
||||||
|
}
|
||||||
|
|
||||||
return &api.PostureCheck{
|
return &api.PostureCheck{
|
||||||
Id: pc.ID,
|
Id: pc.ID,
|
||||||
Name: pc.Name,
|
Name: pc.Name,
|
||||||
@@ -386,3 +406,25 @@ func toProcessCheck(check *api.ProcessCheck) *ProcessCheck {
|
|||||||
Processes: processes,
|
Processes: processes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toDiskEncryptionCheck(check *api.DiskEncryptionCheck) *DiskEncryptionCheck {
|
||||||
|
d := &DiskEncryptionCheck{}
|
||||||
|
if check.LinuxPath != nil {
|
||||||
|
d.LinuxPath = *check.LinuxPath
|
||||||
|
}
|
||||||
|
if check.DarwinPath != nil {
|
||||||
|
d.DarwinPath = *check.DarwinPath
|
||||||
|
}
|
||||||
|
if check.WindowsPath != nil {
|
||||||
|
d.WindowsPath = *check.WindowsPath
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func toDiskEncryptionCheckResponse(check *DiskEncryptionCheck) *api.DiskEncryptionCheck {
|
||||||
|
return &api.DiskEncryptionCheck{
|
||||||
|
LinuxPath: &check.LinuxPath,
|
||||||
|
DarwinPath: &check.DarwinPath,
|
||||||
|
WindowsPath: &check.WindowsPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
52
management/server/posture/disk_encryption.go
Normal file
52
management/server/posture/disk_encryption.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package posture
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DiskEncryptionCheck verifies that specified volumes are encrypted.
|
||||||
|
type DiskEncryptionCheck struct {
|
||||||
|
LinuxPath string
|
||||||
|
DarwinPath string
|
||||||
|
WindowsPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Check = (*DiskEncryptionCheck)(nil)
|
||||||
|
|
||||||
|
// Name returns the name of the check.
|
||||||
|
func (d *DiskEncryptionCheck) Name() string {
|
||||||
|
return DiskEncryptionCheckName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check performs the disk encryption verification for the given peer.
|
||||||
|
func (d *DiskEncryptionCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) {
|
||||||
|
var pathToCheck string
|
||||||
|
|
||||||
|
switch peer.Meta.GoOS {
|
||||||
|
case "linux":
|
||||||
|
pathToCheck = d.LinuxPath
|
||||||
|
case "darwin":
|
||||||
|
pathToCheck = d.DarwinPath
|
||||||
|
case "windows":
|
||||||
|
pathToCheck = d.WindowsPath
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if pathToCheck == "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer.Meta.DiskEncryption.IsEncrypted(pathToCheck), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the configuration of the disk encryption check.
|
||||||
|
func (d *DiskEncryptionCheck) Validate() error {
|
||||||
|
if d.LinuxPath == "" && d.DarwinPath == "" && d.WindowsPath == "" {
|
||||||
|
return fmt.Errorf("%s at least one path must be configured", d.Name())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
306
management/server/posture/disk_encryption_test.go
Normal file
306
management/server/posture/disk_encryption_test.go
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
package posture
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDiskEncryptionCheck_Check(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input peer.Peer
|
||||||
|
check DiskEncryptionCheck
|
||||||
|
wantErr bool
|
||||||
|
isValid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "linux with encrypted root",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: true},
|
||||||
|
{Path: "/home", Encrypted: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linux with unencrypted root",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: false},
|
||||||
|
{Path: "/home", Encrypted: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linux with no volume info",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "darwin with encrypted root",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "darwin",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
DarwinPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "darwin with unencrypted root",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "darwin",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
DarwinPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "windows with encrypted C drive",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "windows",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "C:", Encrypted: true},
|
||||||
|
{Path: "D:", Encrypted: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "windows with unencrypted C drive",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "windows",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "C:", Encrypted: false},
|
||||||
|
{Path: "D:", Encrypted: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported ios operating system",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "ios",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
DarwinPath: "/",
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported android operating system",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "android",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
DarwinPath: "/",
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linux peer with no linux path configured passes",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
DarwinPath: "/",
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "darwin peer with no darwin path configured passes",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "darwin",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "/", Encrypted: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "windows peer with no windows path configured passes",
|
||||||
|
input: peer.Peer{
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
GoOS: "windows",
|
||||||
|
DiskEncryption: peer.DiskEncryptionInfo{
|
||||||
|
Volumes: []peer.DiskEncryptionVolume{
|
||||||
|
{Path: "C:", Encrypted: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
DarwinPath: "/",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
isValid: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
isValid, err := tt.check.Check(context.Background(), tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.isValid, isValid)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiskEncryptionCheck_Validate(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
check DiskEncryptionCheck
|
||||||
|
expectedError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid linux, darwin and windows paths",
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
DarwinPath: "/",
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid linux path only",
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
LinuxPath: "/",
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid darwin path only",
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
DarwinPath: "/",
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid windows path only",
|
||||||
|
check: DiskEncryptionCheck{
|
||||||
|
WindowsPath: "C:",
|
||||||
|
},
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid empty paths",
|
||||||
|
check: DiskEncryptionCheck{},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.check.Validate()
|
||||||
|
if tc.expectedError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiskEncryptionCheck_Name(t *testing.T) {
|
||||||
|
check := DiskEncryptionCheck{}
|
||||||
|
assert.Equal(t, DiskEncryptionCheckName, check.Name())
|
||||||
|
}
|
||||||
@@ -27,6 +27,8 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -123,6 +125,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||||
|
&zones.Zone{}, &records.Record{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||||
@@ -4179,3 +4182,184 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
|
|||||||
|
|
||||||
return userID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
|
||||||
|
result := s.db.Create(zone)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to create zone to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
|
||||||
|
result := s.db.Select("*").Save(zone)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to update zone to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
|
||||||
|
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete zone from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewZoneNotFoundError(zoneID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var zone *zones.Zone
|
||||||
|
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewZoneNotFoundError(zoneID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get zone from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return zone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
|
||||||
|
var zone *zones.Zone
|
||||||
|
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewZoneNotFoundError(domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return zone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var zones []*zones.Zone
|
||||||
|
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get zones from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return zones, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||||
|
result := s.db.Create(record)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to create dns record to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||||
|
result := s.db.Select("*").Save(record)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to update dns record to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
|
||||||
|
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete dns record from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewDNSRecordNotFoundError(recordID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var record *records.Record
|
||||||
|
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewDNSRecordNotFoundError(recordID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var recordsList []*records.Record
|
||||||
|
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return recordsList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var recordsList []*records.Record
|
||||||
|
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return recordsList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
|
||||||
|
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_CreateZone(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, savedZone)
|
||||||
|
assert.Equal(t, zone.ID, savedZone.ID)
|
||||||
|
assert.Equal(t, zone.Name, savedZone.Name)
|
||||||
|
assert.Equal(t, zone.Domain, savedZone.Domain)
|
||||||
|
assert.Equal(t, zone.Enabled, savedZone.Enabled)
|
||||||
|
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
|
||||||
|
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetZoneByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
zoneID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing zone",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: zone.ID,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing zone",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty zone ID",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, savedZone)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, savedZone)
|
||||||
|
assert.Equal(t, tt.zoneID, savedZone.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetAccountZones(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
|
||||||
|
err = store.CreateZone(context.Background(), zone2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, allZones)
|
||||||
|
assert.GreaterOrEqual(t, len(allZones), 2)
|
||||||
|
|
||||||
|
zoneIDs := make(map[string]bool)
|
||||||
|
for _, z := range allZones {
|
||||||
|
zoneIDs[z.ID] = true
|
||||||
|
}
|
||||||
|
assert.True(t, zoneIDs[zone1.ID])
|
||||||
|
assert.True(t, zoneIDs[zone2.ID])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetZoneByDomain(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
domain string
|
||||||
|
expectError bool
|
||||||
|
errorType status.Type
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing zone by domain",
|
||||||
|
accountID: accountID,
|
||||||
|
domain: "example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing zone domain",
|
||||||
|
accountID: accountID,
|
||||||
|
domain: "non-existing.com",
|
||||||
|
expectError: true,
|
||||||
|
errorType: status.NotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty domain",
|
||||||
|
accountID: accountID,
|
||||||
|
domain: "",
|
||||||
|
expectError: true,
|
||||||
|
errorType: status.NotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with different account ID",
|
||||||
|
accountID: otherAccountID,
|
||||||
|
domain: "example.com",
|
||||||
|
expectError: true,
|
||||||
|
errorType: status.NotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, tt.errorType, sErr.Type())
|
||||||
|
require.Nil(t, savedZone)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, savedZone)
|
||||||
|
assert.Equal(t, tt.domain, savedZone.Domain)
|
||||||
|
assert.Equal(t, zone.ID, savedZone.ID)
|
||||||
|
assert.Equal(t, zone.Name, savedZone.Name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_UpdateZone(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
zone.Name = "Updated Zone"
|
||||||
|
zone.Domain = "updated.com"
|
||||||
|
zone.Enabled = false
|
||||||
|
zone.EnableSearchDomain = true
|
||||||
|
zone.DistributionGroups = []string{"group2", "group3"}
|
||||||
|
|
||||||
|
err = store.UpdateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedZone)
|
||||||
|
assert.Equal(t, "Updated Zone", updatedZone.Name)
|
||||||
|
assert.Equal(t, "updated.com", updatedZone.Domain)
|
||||||
|
assert.False(t, updatedZone.Enabled)
|
||||||
|
assert.True(t, updatedZone.EnableSearchDomain)
|
||||||
|
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteZone(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = store.DeleteZone(context.Background(), accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, deletedZone)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_CreateDNSRecord(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, savedRecord)
|
||||||
|
assert.Equal(t, record.ID, savedRecord.ID)
|
||||||
|
assert.Equal(t, record.Name, savedRecord.Name)
|
||||||
|
assert.Equal(t, record.Type, savedRecord.Type)
|
||||||
|
assert.Equal(t, record.Content, savedRecord.Content)
|
||||||
|
assert.Equal(t, record.TTL, savedRecord.TTL)
|
||||||
|
assert.Equal(t, zone.ID, savedRecord.ZoneID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
zoneID string
|
||||||
|
recordID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing record",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: zone.ID,
|
||||||
|
recordID: record.ID,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing record",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: zone.ID,
|
||||||
|
recordID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty record ID",
|
||||||
|
accountID: accountID,
|
||||||
|
zoneID: zone.ID,
|
||||||
|
recordID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, savedRecord)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, savedRecord)
|
||||||
|
assert.Equal(t, tt.recordID, savedRecord.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), recordA)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), recordAAAA)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), recordCNAME)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, allRecords)
|
||||||
|
assert.Equal(t, 3, len(allRecords))
|
||||||
|
|
||||||
|
recordIDs := make(map[string]bool)
|
||||||
|
for _, r := range allRecords {
|
||||||
|
recordIDs[r.ID] = true
|
||||||
|
}
|
||||||
|
assert.True(t, recordIDs[recordA.ID])
|
||||||
|
assert.True(t, recordIDs[recordAAAA.ID])
|
||||||
|
assert.True(t, recordIDs[recordCNAME.ID])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, recordsByName)
|
||||||
|
assert.Equal(t, 2, len(recordsByName))
|
||||||
|
|
||||||
|
for _, r := range recordsByName {
|
||||||
|
assert.Equal(t, "www.example.com", r.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record.Name = "api.example.com"
|
||||||
|
record.Content = "192.168.1.100"
|
||||||
|
record.TTL = 600
|
||||||
|
|
||||||
|
err = store.UpdateDNSRecord(context.Background(), record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedRecord)
|
||||||
|
assert.Equal(t, "api.example.com", updatedRecord.Name)
|
||||||
|
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
|
||||||
|
assert.Equal(t, 600, updatedRecord.TTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, deletedRecord)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||||
|
err = store.CreateZone(context.Background(), zone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||||
|
err = store.CreateDNSRecord(context.Background(), record2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, len(allRecords))
|
||||||
|
|
||||||
|
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(remainingRecords))
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/dns"
|
"github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/testutil"
|
"github.com/netbirdio/netbird/management/server/testutil"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -209,6 +211,21 @@ type Store interface {
|
|||||||
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
||||||
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
||||||
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||||
|
|
||||||
|
CreateZone(ctx context.Context, zone *zones.Zone) error
|
||||||
|
UpdateZone(ctx context.Context, zone *zones.Zone) error
|
||||||
|
DeleteZone(ctx context.Context, accountID, zoneID string) error
|
||||||
|
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
|
||||||
|
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
|
||||||
|
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
|
||||||
|
|
||||||
|
CreateDNSRecord(ctx context.Context, record *records.Record) error
|
||||||
|
UpdateDNSRecord(ctx context.Context, record *records.Record) error
|
||||||
|
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
|
||||||
|
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
|
||||||
|
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
|
||||||
|
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
|
||||||
|
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -150,17 +152,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
|||||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||||
peerRoutesMembership := make(LookupMap)
|
peerRoutesMembership := make(LookupMap)
|
||||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupListMap := a.GetPeerGroups(peerID)
|
|
||||||
for _, peer := range aclPeers {
|
for _, peer := range aclPeers {
|
||||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||||
routes = append(routes, filteredRoutes...)
|
routes = append(routes, filteredRoutes...)
|
||||||
}
|
}
|
||||||
@@ -274,6 +275,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
peerID string,
|
peerID string,
|
||||||
peersCustomZone nbdns.CustomZone,
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
validatedPeersMap map[string]struct{},
|
validatedPeersMap map[string]struct{},
|
||||||
resourcePolicies map[string][]*Policy,
|
resourcePolicies map[string][]*Policy,
|
||||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||||
@@ -294,6 +296,8 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerGroups := a.GetPeerGroups(peerID)
|
||||||
|
|
||||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||||
// exclude expired peers
|
// exclude expired peers
|
||||||
var peersToConnect []*nbpeer.Peer
|
var peersToConnect []*nbpeer.Peer
|
||||||
@@ -307,7 +311,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
peersToConnect = append(peersToConnect, p)
|
peersToConnect = append(peersToConnect, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
||||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||||
@@ -323,6 +327,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
|
|
||||||
if dnsManagementStatus {
|
if dnsManagementStatus {
|
||||||
var zones []nbdns.CustomZone
|
var zones []nbdns.CustomZone
|
||||||
|
|
||||||
if peersCustomZone.Domain != "" {
|
if peersCustomZone.Domain != "" {
|
||||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||||
zones = append(zones, nbdns.CustomZone{
|
zones = append(zones, nbdns.CustomZone{
|
||||||
@@ -330,6 +335,10 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
Records: records,
|
Records: records,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||||
|
zones = append(zones, filteredAccountZones...)
|
||||||
|
|
||||||
dnsUpdate.CustomZones = zones
|
dnsUpdate.CustomZones = zones
|
||||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||||
}
|
}
|
||||||
@@ -1881,3 +1890,66 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
|
|||||||
|
|
||||||
return filteredRecords
|
return filteredRecords
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterPeerAppliedZones filters account zones based on the peer's group membership
|
||||||
|
func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone {
|
||||||
|
var customZones []nbdns.CustomZone
|
||||||
|
|
||||||
|
if len(peerGroups) == 0 {
|
||||||
|
return customZones
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, zone := range accountZones {
|
||||||
|
if !zone.Enabled || len(zone.Records) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hasAccess := false
|
||||||
|
for _, distGroupID := range zone.DistributionGroups {
|
||||||
|
if _, found := peerGroups[distGroupID]; found {
|
||||||
|
hasAccess = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasAccess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records))
|
||||||
|
for _, record := range zone.Records {
|
||||||
|
var recordType int
|
||||||
|
rData := record.Content
|
||||||
|
|
||||||
|
switch record.Type {
|
||||||
|
case records.RecordTypeA:
|
||||||
|
recordType = int(dns.TypeA)
|
||||||
|
case records.RecordTypeAAAA:
|
||||||
|
recordType = int(dns.TypeAAAA)
|
||||||
|
case records.RecordTypeCNAME:
|
||||||
|
recordType = int(dns.TypeCNAME)
|
||||||
|
rData = dns.Fqdn(record.Content)
|
||||||
|
default:
|
||||||
|
log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleRecords = append(simpleRecords, nbdns.SimpleRecord{
|
||||||
|
Name: dns.Fqdn(record.Name),
|
||||||
|
Type: recordType,
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: record.TTL,
|
||||||
|
RData: rData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
customZones = append(customZones, nbdns.CustomZone{
|
||||||
|
Domain: dns.Fqdn(zone.Domain),
|
||||||
|
Records: simpleRecords,
|
||||||
|
SearchDomainDisabled: !zone.EnableSearchDomain,
|
||||||
|
NonAuthoritative: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return customZones
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -1425,3 +1427,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_filterPeerAppliedZones(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountZones []*zones.Zone
|
||||||
|
peerGroups LookupMap
|
||||||
|
expected []nbdns.CustomZone
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty peer groups returns empty custom zones",
|
||||||
|
accountZones: []*zones.Zone{},
|
||||||
|
peerGroups: LookupMap{},
|
||||||
|
expected: []nbdns.CustomZone{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer has access to zone with A record",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.example.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "example.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.example.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer has access to zone with search domain enabled",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "internal.local",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: true,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "api.internal.local",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "10.0.0.1",
|
||||||
|
TTL: 600,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "internal.local.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "api.internal.local.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 600,
|
||||||
|
RData: "10.0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer has no access to zone",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "private.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group2"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "secret.private.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled zone is filtered out",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "disabled.com",
|
||||||
|
Enabled: false,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.disabled.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zone with no records is filtered out",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "empty.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer has access via multiple groups",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "multi.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1", "group2", "group3"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.multi.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group2": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "multi.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.multi.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple zones with mixed access",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "allowed.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.allowed.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "zone2",
|
||||||
|
Domain: "denied.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group2"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record2",
|
||||||
|
Name: "www.denied.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.2",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "allowed.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.allowed.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zone with multiple record types",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "mixed.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.mixed.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "record2",
|
||||||
|
Name: "ipv6.mixed.com",
|
||||||
|
Type: records.RecordTypeAAAA,
|
||||||
|
Content: "2001:db8::1",
|
||||||
|
TTL: 600,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "record3",
|
||||||
|
Name: "alias.mixed.com",
|
||||||
|
Type: records.RecordTypeCNAME,
|
||||||
|
Content: "www.mixed.com",
|
||||||
|
TTL: 900,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "mixed.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.mixed.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "ipv6.mixed.com.",
|
||||||
|
Type: int(dns.TypeAAAA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 600,
|
||||||
|
RData: "2001:db8::1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "alias.mixed.com.",
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 900,
|
||||||
|
RData: "www.mixed.com.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple zones both accessible",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "first.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: true,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.first.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "zone2",
|
||||||
|
Domain: "second.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record2",
|
||||||
|
Name: "www.second.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.2",
|
||||||
|
TTL: 600,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "first.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.first.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Domain: "second.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.second.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 600,
|
||||||
|
RData: "192.168.1.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zone with multiple records of same type",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "multi-a.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.multi-a.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "record2",
|
||||||
|
Name: "www.multi-a.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.2",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "multi-a.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.multi-a.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "www.multi-a.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peer in multiple groups accessing different zones",
|
||||||
|
accountZones: []*zones.Zone{
|
||||||
|
{
|
||||||
|
ID: "zone1",
|
||||||
|
Domain: "zone1.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group1"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record1",
|
||||||
|
Name: "www.zone1.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.1",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "zone2",
|
||||||
|
Domain: "zone2.com",
|
||||||
|
Enabled: true,
|
||||||
|
EnableSearchDomain: false,
|
||||||
|
DistributionGroups: []string{"group2"},
|
||||||
|
Records: []*records.Record{
|
||||||
|
{
|
||||||
|
ID: "record2",
|
||||||
|
Name: "www.zone2.com",
|
||||||
|
Type: records.RecordTypeA,
|
||||||
|
Content: "192.168.1.2",
|
||||||
|
TTL: 300,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}},
|
||||||
|
expected: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "zone1.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.zone1.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Domain: "zone2.com.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.zone2.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SearchDomainDisabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups)
|
||||||
|
require.Equal(t, len(tt.expected), len(result), "number of custom zones should match")
|
||||||
|
|
||||||
|
for i, expectedZone := range tt.expected {
|
||||||
|
assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match")
|
||||||
|
assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match")
|
||||||
|
assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match")
|
||||||
|
|
||||||
|
for j, expectedRecord := range expectedZone.Records {
|
||||||
|
assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match")
|
||||||
|
assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match")
|
||||||
|
assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match")
|
||||||
|
assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match")
|
||||||
|
assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
@@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
peerID string,
|
peerID string,
|
||||||
peersCustomZone nbdns.CustomZone,
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
validatedPeers map[string]struct{},
|
validatedPeers map[string]struct{},
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
) *NetworkMap {
|
) *NetworkMap {
|
||||||
a.initNetworkMapBuilder(validatedPeers)
|
a.initNetworkMapBuilder(validatedPeers)
|
||||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||||
|
|||||||
@@ -70,13 +70,13 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|||||||
b.Run("old builder", func(b *testing.B) {
|
b.Run("old builder", func(b *testing.B) {
|
||||||
for range b.N {
|
for range b.N {
|
||||||
for _, peerID := range peerIDs {
|
for _, peerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -124,7 +124,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|||||||
for range b.N {
|
for range b.N {
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||||
for _, peerID := range peerIDs {
|
for _, peerID := range peerIDs {
|
||||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
|
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -185,7 +185,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|||||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||||
require.NoError(t, err, "error adding peer to cache")
|
require.NoError(t, err, "error adding peer to cache")
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -250,7 +250,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -325,7 +325,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|||||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||||
require.NoError(t, err, "error adding router to cache")
|
require.NoError(t, err, "error adding router to cache")
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -412,7 +412,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -466,7 +466,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|||||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
err = builder.OnPeerDeleted(account, deletedPeerID)
|
||||||
require.NoError(t, err, "error deleting peer from cache")
|
require.NoError(t, err, "error deleting peer from cache")
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
@@ -545,7 +545,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|||||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
err = builder.OnPeerDeleted(account, deletedRouterID)
|
||||||
require.NoError(t, err, "error deleting routing peer from cache")
|
require.NoError(t, err, "error deleting routing peer from cache")
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|||||||
b.Run("old builder after delete", func(b *testing.B) {
|
b.Run("old builder after delete", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -607,7 +607,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -944,7 +944,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T
|
|||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@@ -1033,7 +1034,7 @@ func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
|
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone,
|
||||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||||
) *NetworkMap {
|
) *NetworkMap {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -1057,7 +1058,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
|||||||
return &NetworkMap{Network: account.Network.Copy()}
|
return &NetworkMap{Network: account.Network.Copy()}
|
||||||
}
|
}
|
||||||
|
|
||||||
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers)
|
||||||
|
|
||||||
if metrics != nil {
|
if metrics != nil {
|
||||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||||
@@ -1074,8 +1075,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *NetworkMapBuilder) assembleNetworkMap(
|
func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{},
|
||||||
) *NetworkMap {
|
) *NetworkMap {
|
||||||
|
|
||||||
var peersToConnect []*nbpeer.Peer
|
var peersToConnect []*nbpeer.Peer
|
||||||
@@ -1125,13 +1126,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
finalDNSConfig := *dnsConfig
|
finalDNSConfig := *dnsConfig
|
||||||
if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
|
if finalDNSConfig.ServiceEnable {
|
||||||
var zones []nbdns.CustomZone
|
var zones []nbdns.CustomZone
|
||||||
records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
|
|
||||||
|
peerGroupsSlice := b.cache.peerToGroups[peer.ID]
|
||||||
|
peerGroups := make(LookupMap, len(peerGroupsSlice))
|
||||||
|
for _, groupID := range peerGroupsSlice {
|
||||||
|
peerGroups[groupID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peersCustomZone.Domain != "" {
|
||||||
|
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers)
|
||||||
zones = append(zones, nbdns.CustomZone{
|
zones = append(zones, nbdns.CustomZone{
|
||||||
Domain: customZone.Domain,
|
Domain: peersCustomZone.Domain,
|
||||||
Records: records,
|
Records: records,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||||
|
zones = append(zones, filteredAccountZones...)
|
||||||
|
|
||||||
finalDNSConfig.CustomZones = zones
|
finalDNSConfig.CustomZones = zones
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
accountUsers := []*types.User{}
|
accountUsers := []*types.User{}
|
||||||
switch {
|
switch {
|
||||||
case allowed:
|
case allowed:
|
||||||
|
start := time.Now()
|
||||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
|
||||||
case user != nil && user.AccountID == accountID:
|
case user != nil && user.AccountID == accountID:
|
||||||
accountUsers = append(accountUsers, user)
|
accountUsers = append(accountUsers, user)
|
||||||
default:
|
default:
|
||||||
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
|||||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||||
usersFromIntegration := make([]*idp.UserData, 0)
|
usersFromIntegration := make([]*idp.UserData, 0)
|
||||||
|
filtered := make(map[string]*idp.UserData, len(accountUsers))
|
||||||
|
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
integrationKeys := make(map[string]struct{})
|
||||||
for _, user := range accountUsers {
|
for _, user := range accountUsers {
|
||||||
if user.Issued == types.UserIssuedIntegration {
|
if user.Issued == types.UserIssuedIntegration {
|
||||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
|
||||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
|
||||||
users[user.Id] = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
usersFromIntegration = append(usersFromIntegration, info)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !user.IsServiceUser {
|
if !user.IsServiceUser {
|
||||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for key := range integrationKeys {
|
||||||
|
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, ud := range usersData {
|
||||||
|
filtered[ud.ID] = ud
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ud := range filtered {
|
||||||
|
usersFromIntegration = append(usersFromIntegration, ud)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
|
||||||
|
start = time.Now()
|
||||||
queriedUsers, err = am.lookupCache(ctx, users, accountID)
|
queriedUsers, err = am.lookupCache(ctx, users, accountID)
|
||||||
|
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
cacheManager := am.GetExternalCacheManager()
|
cacheManager := am.GetExternalCacheManager()
|
||||||
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
|
||||||
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
|
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||||
|
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
|
||||||
|
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
|
import "regexp"
|
||||||
|
|
||||||
|
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
|
||||||
|
|
||||||
// Difference returns the elements in `a` that aren't in `b`.
|
// Difference returns the elements in `a` that aren't in `b`.
|
||||||
func Difference(a, b []string) []string {
|
func Difference(a, b []string) []string {
|
||||||
mb := make(map[string]struct{}, len(b))
|
mb := make(map[string]struct{}, len(b))
|
||||||
@@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsValidDomain(domain string) bool {
|
||||||
|
if domain == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return domainRegex.MatchString(domain)
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||||
"github.com/netbirdio/netbird/signal/metrics"
|
"github.com/netbirdio/netbird/signal/metrics"
|
||||||
|
"github.com/netbirdio/netbird/stun"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +45,10 @@ type Config struct {
|
|||||||
LogLevel string
|
LogLevel string
|
||||||
LogFile string
|
LogFile string
|
||||||
HealthcheckListenAddress string
|
HealthcheckListenAddress string
|
||||||
|
// STUN server configuration
|
||||||
|
EnableSTUN bool
|
||||||
|
STUNPorts []int
|
||||||
|
STUNLogLevel string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) Validate() error {
|
func (c Config) Validate() error {
|
||||||
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
|
|||||||
if c.AuthSecret == "" {
|
if c.AuthSecret == "" {
|
||||||
return fmt.Errorf("auth secret is required")
|
return fmt.Errorf("auth secret is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate STUN configuration
|
||||||
|
if c.EnableSTUN {
|
||||||
|
if len(c.STUNPorts) == 0 {
|
||||||
|
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[int]bool)
|
||||||
|
for _, port := range c.STUNPorts {
|
||||||
|
if port <= 0 || port > 65535 {
|
||||||
|
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
|
||||||
|
}
|
||||||
|
if seen[port] {
|
||||||
|
return fmt.Errorf("duplicate STUN port %d", port)
|
||||||
|
}
|
||||||
|
seen[port] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +116,9 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
|
||||||
|
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
|
||||||
|
|
||||||
setFlagsFromEnvVars(rootCmd)
|
setFlagsFromEnvVars(rootCmd)
|
||||||
}
|
}
|
||||||
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to initialize log: %s", err)
|
return fmt.Errorf("failed to initialize log: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resource creation phase (fail fast before starting any goroutines)
|
||||||
|
|
||||||
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("setup metrics: %v", err)
|
log.Debugf("setup metrics: %v", err)
|
||||||
return fmt.Errorf("setup metrics: %v", err)
|
return fmt.Errorf("setup metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
|
||||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
log.Fatalf("Failed to start metrics server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
srvListenerCfg := server.ListenerConfig{
|
srvListenerCfg := server.ListenerConfig{
|
||||||
Address: cobraConfig.ListenAddress,
|
Address: cobraConfig.ListenAddress,
|
||||||
}
|
}
|
||||||
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
srvListenerCfg.TLSConfig = tlsConfig
|
srvListenerCfg.TLSConfig = tlsConfig
|
||||||
|
|
||||||
|
// Create STUN listeners early to fail fast
|
||||||
|
stunListeners, err := createSTUNListeners()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
TLSSupport: tlsSupport,
|
TLSSupport: tlsSupport,
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := server.NewServer(cfg)
|
srv, err := createRelayServer(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create relay server: %v", err)
|
cleanupSTUNListeners(stunListeners)
|
||||||
return fmt.Errorf("failed to create relay server: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hCfg := healthcheck.Config{
|
||||||
|
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
||||||
|
ServiceChecker: srv,
|
||||||
|
}
|
||||||
|
httpHealthcheck, err := createHealthCheck(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var stunServer *stun.Server
|
||||||
|
if len(stunListeners) > 0 {
|
||||||
|
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all servers (only after all resources are successfully created)
|
||||||
|
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
|
||||||
|
|
||||||
|
waitForExitSignal()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
|
||||||
|
wg.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||||
|
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("failed to start metrics server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
instanceURL := srv.InstanceURL()
|
instanceURL := srv.InstanceURL()
|
||||||
log.Infof("server will be available on: %s", instanceURL.String())
|
log.Infof("server will be available on: %s", instanceURL.String())
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := srv.Listen(srvListenerCfg); err != nil {
|
if err := srv.Listen(srvListenerCfg); err != nil {
|
||||||
log.Fatalf("failed to bind server: %s", err)
|
log.Fatalf("failed to bind relay server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hCfg := healthcheck.Config{
|
|
||||||
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
|
||||||
ServiceChecker: srv,
|
|
||||||
}
|
|
||||||
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create healthcheck server: %v", err)
|
|
||||||
return fmt.Errorf("failed to create healthcheck server: %v", err)
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("Failed to start healthcheck server: %v", err)
|
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// it will block until exit signal
|
if stunServer != nil {
|
||||||
waitForExitSignal()
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := stunServer.Listen(); err != nil {
|
||||||
|
if errors.Is(err, stun.ErrServerClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("STUN server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
|
||||||
defer cancel()
|
var errs error
|
||||||
|
|
||||||
var shutDownErrors error
|
|
||||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if stunServer != nil {
|
||||||
|
if err := stunServer.Shutdown(); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("shutting down metrics server")
|
log.Infof("shutting down metrics server")
|
||||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
return errs
|
||||||
return shutDownErrors
|
}
|
||||||
|
|
||||||
|
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
|
||||||
|
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create healthcheck server: %v", err)
|
||||||
|
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
|
||||||
|
}
|
||||||
|
return httpHealthcheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRelayServer(cfg server.Config) (*server.Server, error) {
|
||||||
|
srv, err := server.NewServer(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
|
||||||
|
for _, l := range stunListeners {
|
||||||
|
_ = l.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSTUNListeners() ([]*net.UDPConn, error) {
|
||||||
|
var stunListeners []*net.UDPConn
|
||||||
|
if cobraConfig.EnableSTUN {
|
||||||
|
for _, port := range cobraConfig.STUNPorts {
|
||||||
|
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
// Close already opened listeners on failure
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
|
||||||
|
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
|
||||||
|
}
|
||||||
|
stunListeners = append(stunListeners, listener)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stunListeners, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ pid="$(pgrep -x -f /usr/bin/netbird-ui || true)"
|
|||||||
if [ -n "${pid}" ]
|
if [ -n "${pid}" ]
|
||||||
then
|
then
|
||||||
uid="$(cat /proc/"${pid}"/loginuid)"
|
uid="$(cat /proc/"${pid}"/loginuid)"
|
||||||
|
# loginuid can be 4294967295 (-1) if not set, fall back to process uid
|
||||||
|
if [ "${uid}" = "4294967295" ] || [ "${uid}" = "-1" ]; then
|
||||||
|
uid="$(stat -c '%u' /proc/"${pid}")"
|
||||||
|
fi
|
||||||
username="$(id -nu "${uid}")"
|
username="$(id -nu "${uid}")"
|
||||||
# Only re-run if it was already running
|
# Only re-run if it was already running
|
||||||
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1
|
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user