mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 19:56:46 +00:00
Compare commits
14 Commits
ci-win-tes
...
feature/di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
279e96e6b1 | ||
|
|
245481f33b | ||
|
|
b352ab84c0 | ||
|
|
3ce5d6a4f8 | ||
|
|
4c2eb2af73 | ||
|
|
daf1449174 | ||
|
|
1ff7abe909 | ||
|
|
067c77e49e | ||
|
|
291e640b28 | ||
|
|
efb954b7d6 | ||
|
|
cac9326d3d | ||
|
|
520d9c66cf | ||
|
|
ff10498a8b | ||
|
|
00b747ad5d |
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.2
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||
statusOutputString = overview.FullDetailSummary()
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
case jsonFlag:
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,6 +39,7 @@ type Client struct {
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
if c.connect != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||
defer func() {
|
||||
if c.connect == nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
@@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
c.cancel = cancel
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
connect := c.connect
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
done <- connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
return recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncResp, err := engine.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get sync response: %w", err)
|
||||
}
|
||||
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect != nil {
|
||||
connect.SetLogLevel(level)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
|
||||
@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
||||
return syncResponse, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager != nil {
|
||||
fwManager.SetLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
|
||||
@@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
@@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
domainGroup.domain,
|
||||
@@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.wgInterface,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
||||
return configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -2047,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||
|
||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
// Test that priority constants are ordered correctly
|
||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
|
||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
@@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
|
||||
conn, err := nsNet.DialContext(ctx, network, upstream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("with %s: %w", network, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close DNS connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
}
|
||||
|
||||
reply, err := dnsConn.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s message: %w", network, err)
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -23,9 +23,7 @@ type upstreamResolver struct {
|
||||
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
_ WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
|
||||
@@ -5,22 +5,23 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
*upstreamResolverBase
|
||||
nsNet *netstack.Net
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -28,12 +29,23 @@ func newUpstreamResolver(
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = nonIOS
|
||||
return nonIOS, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
|
||||
// Similar to iOS logic, we should determine if the DNS server is reachable directly
|
||||
// or needs to go through the tunnel, and only use netstack when necessary.
|
||||
// For now, only use netstack on JS platform where direct access is not possible.
|
||||
if u.nsNet != nil && runtime.GOOS == "js" {
|
||||
start := time.Now()
|
||||
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
|
||||
return reply, time.Since(start), err
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
|
||||
@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -37,9 +35,9 @@ func newUpstreamResolver(
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
interfaceName: interfaceName,
|
||||
lIP: wgIface.Address().IP,
|
||||
lNet: wgIface.Address().Network,
|
||||
interfaceName: wgIface.Name(),
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
|
||||
@@ -2,13 +2,17 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
|
||||
// Convert test servers to netip.AddrPort
|
||||
var servers []netip.AddrPort
|
||||
for _, server := range testCase.InputServers {
|
||||
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockNetstackProvider struct{}
|
||||
|
||||
func (m *mockNetstackProvider) Name() string { return "mock" }
|
||||
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
|
||||
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
|
||||
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
|
||||
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
|
||||
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
|
||||
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
|
||||
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
|
||||
@@ -5,6 +5,8 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -17,4 +19,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -12,5 +14,6 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -1748,6 +1748,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
// Skip STUN/TURN probing for JS/WASM as it's not available
|
||||
relayHealthy := true
|
||||
if runtime.GOOS != "js" {
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
@@ -1756,7 +1760,6 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
for _, res := range results {
|
||||
if res.Err != nil {
|
||||
relayHealthy = false
|
||||
@@ -1764,6 +1767,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
}
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
}
|
||||
|
||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||
|
||||
@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
|
||||
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -158,6 +159,7 @@ type FullStatus struct {
|
||||
NSGroupStates []NSGroupState
|
||||
NumOfForwardingRules int
|
||||
LazyConnectionEnabled bool
|
||||
Events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
type StatusChangeSubscription struct {
|
||||
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
}
|
||||
|
||||
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||
fullStatus.Events = d.GetEventHistory()
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
|
||||
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||
return s.events
|
||||
}
|
||||
|
||||
// ToProto converts FullStatus to proto.FullStatus.
|
||||
func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
|
||||
if err := fs.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fs.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
|
||||
if err := fs.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
|
||||
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
|
||||
|
||||
for _, peerState := range fs.Peers {
|
||||
networks := maps.Keys(peerState.GetRoutes())
|
||||
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: networks,
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fs.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fs.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
pbFullStatus.Events = fs.Events
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -38,11 +38,6 @@ type internalDNATer interface {
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
@@ -52,7 +47,7 @@ type DnsInterceptor struct {
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
wgInterface wgInterface
|
||||
wgInterface iface.WGIface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
@@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
@@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
|
||||
if reply == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
||||
return
|
||||
}
|
||||
|
||||
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
|
||||
// Returns the DNS reply on success, or nil on error (error responses are written internally).
|
||||
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
|
||||
startTime := time.Now()
|
||||
|
||||
nsNet := d.wgInterface.GetNet()
|
||||
var reply *dns.Msg
|
||||
var err error
|
||||
|
||||
if nsNet != nil {
|
||||
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||
} else {
|
||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if clientErr != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||
return nil
|
||||
}
|
||||
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return reply
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||
if d.statusRecorder == nil {
|
||||
return ""
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("connect client not initialized")
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, fmt.Errorf("firewall manager not initialized")
|
||||
}
|
||||
|
||||
fwManager.SetLogLevel(level)
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
|
||||
events := s.statusRecorder.GetEventHistory()
|
||||
return &proto.GetEventsResponse{Events: events}, nil
|
||||
}
|
||||
|
||||
@@ -13,15 +13,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||
if err := fullStatus.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||
if err := fullStatus.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fullStatus.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fullStatus.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
Audiences: []string{"test-audience"},
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
@@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
|
||||
func TestJWTMultipleAudiences(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT multiple audiences tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
dashboardAudience = "dashboard-audience"
|
||||
cliAudience = "cli-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantAuthOK bool
|
||||
}{
|
||||
{
|
||||
name: "accepts_dashboard_audience",
|
||||
audience: dashboardAudience,
|
||||
wantAuthOK: true,
|
||||
},
|
||||
{
|
||||
name: "accepts_cli_audience",
|
||||
audience: cliAudience,
|
||||
wantAuthOK: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_unknown_audience",
|
||||
audience: "unknown-audience",
|
||||
wantAuthOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{dashboardAudience, cliAudience},
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
currentUser: {0},
|
||||
},
|
||||
}
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := generateValidJWT(t, privateKey, issuer, tc.audience)
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
err = session.Shell()
|
||||
require.NoError(t, err, "Shell should work with valid audience")
|
||||
} else {
|
||||
assert.Error(t, err, "JWT authentication should fail for unknown audience")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,9 +176,9 @@ type Server struct {
|
||||
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
if len(config.Audiences) == 0 {
|
||||
return fmt.Errorf("JWT config has no audiences configured")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.Audiences,
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
// Use custom userIDClaim from authorizer if available
|
||||
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||
jwt.WithAudience(config.Audience),
|
||||
jwt.WithAudience(config.Audiences[0]),
|
||||
}
|
||||
if authorizer.GetUserIDClaim() != "" {
|
||||
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
|
||||
@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseToJSON(overview OutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
// JSON returns the status overview as a JSON string.
|
||||
func (o *OutputOverview) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
func ParseToYAML(overview OutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
// YAML returns the status overview as a YAML string.
|
||||
func (o *OutputOverview) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
// GeneralSummary returns a general summary of the status overview.
|
||||
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
if o.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
if o.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
if o.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
if o.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceIP := o.IP
|
||||
if o.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if overview.IP == "" {
|
||||
} else if o.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
for _, relay := range o.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
|
||||
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
if len(o.Networks) > 0 {
|
||||
sort.Strings(o.Networks)
|
||||
networks = strings.Join(o.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
for _, nsServerGroup := range o.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
if o.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
if o.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
lazyConnectionEnabledStatus := "false"
|
||||
if overview.LazyConnectionEnabled {
|
||||
if o.LazyConnectionEnabled {
|
||||
lazyConnectionEnabledStatus = "true"
|
||||
}
|
||||
|
||||
sshServerStatus := "Disabled"
|
||||
if overview.SSHServerState.Enabled {
|
||||
sessionCount := len(overview.SSHServerState.Sessions)
|
||||
if o.SSHServerState.Enabled {
|
||||
sessionCount := len(o.SSHServerState.Sessions)
|
||||
if sessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if sessionCount > 1 {
|
||||
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
}
|
||||
|
||||
if showSSHSessions && sessionCount > 0 {
|
||||
for _, session := range overview.SSHServerState.Sessions {
|
||||
for _, session := range o.SSHServerState.Sessions {
|
||||
var sessionDisplay string
|
||||
if session.JWTUsername != "" {
|
||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
||||
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
"Forwarding rules: %d\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
o.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
overview.ProfileName,
|
||||
o.ProfileName,
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
domain.Domain(overview.FQDN).SafeString(),
|
||||
domain.Domain(o.FQDN).SafeString(),
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
networks,
|
||||
overview.NumberOfForwardingRules,
|
||||
o.NumberOfForwardingRules,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(overview.Events)
|
||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
||||
// FullDetailSummary returns a full detailed summary with peer details and events.
|
||||
func (o *OutputOverview) FullDetailSummary() string {
|
||||
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(o.Events)
|
||||
summary := o.GeneralSummary(true, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := ParseToJSON(overview)
|
||||
jsonString, _ := overview.JSON()
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := ParseToYAML(overview)
|
||||
yaml, _ := overview.YAML()
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := ParseToFullDetailSummary(overview)
|
||||
detail := overview.FullDetailSummary()
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
||||
shortVersion := overview.GeneralSummary(false, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
|
||||
22
client/system/disk_encryption.go
Normal file
22
client/system/disk_encryption.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package system
|
||||
|
||||
// DiskEncryptionVolume represents encryption status of a single volume.
|
||||
type DiskEncryptionVolume struct {
|
||||
Path string
|
||||
Encrypted bool
|
||||
}
|
||||
|
||||
// DiskEncryptionInfo holds disk encryption detection results.
|
||||
type DiskEncryptionInfo struct {
|
||||
Volumes []DiskEncryptionVolume
|
||||
}
|
||||
|
||||
// IsEncrypted returns true if the volume at the given path is encrypted.
|
||||
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||
for _, v := range d.Volumes {
|
||||
if v.Path == path {
|
||||
return v.Encrypted
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
35
client/system/disk_encryption_darwin.go
Normal file
35
client/system/disk_encryption_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// detectDiskEncryption detects FileVault encryption status on macOS.
|
||||
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "fdesetup", "status")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Debugf("execute fdesetup: %v", err)
|
||||
return info
|
||||
}
|
||||
|
||||
encrypted := strings.Contains(string(output), "FileVault is On")
|
||||
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||
Path: "/",
|
||||
Encrypted: encrypted,
|
||||
})
|
||||
|
||||
return info
|
||||
}
|
||||
98
client/system/disk_encryption_linux.go
Normal file
98
client/system/disk_encryption_linux.go
Normal file
@@ -0,0 +1,98 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// detectDiskEncryption detects LUKS encryption status on Linux by reading sysfs.
|
||||
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
encryptedDevices := findEncryptedDevices()
|
||||
mountPoints := parseMounts(encryptedDevices)
|
||||
|
||||
info.Volumes = mountPoints
|
||||
return info
|
||||
}
|
||||
|
||||
// findEncryptedDevices scans /sys/block for dm-crypt (LUKS) encrypted devices.
|
||||
func findEncryptedDevices() map[string]bool {
|
||||
encryptedDevices := make(map[string]bool)
|
||||
|
||||
sysBlock := "/sys/block"
|
||||
entries, err := os.ReadDir(sysBlock)
|
||||
if err != nil {
|
||||
log.Debugf("read /sys/block: %v", err)
|
||||
return encryptedDevices
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
dmUuidPath := filepath.Join(sysBlock, entry.Name(), "dm", "uuid")
|
||||
data, err := os.ReadFile(dmUuidPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
uuid := strings.TrimSpace(string(data))
|
||||
if strings.HasPrefix(uuid, "CRYPT-") {
|
||||
dmNamePath := filepath.Join(sysBlock, entry.Name(), "dm", "name")
|
||||
if nameData, err := os.ReadFile(dmNamePath); err == nil {
|
||||
dmName := strings.TrimSpace(string(nameData))
|
||||
encryptedDevices["/dev/mapper/"+dmName] = true
|
||||
}
|
||||
encryptedDevices["/dev/"+entry.Name()] = true
|
||||
}
|
||||
}
|
||||
|
||||
return encryptedDevices
|
||||
}
|
||||
|
||||
// parseMounts reads /proc/mounts and maps devices to mount points with encryption status.
|
||||
func parseMounts(encryptedDevices map[string]bool) []DiskEncryptionVolume {
|
||||
var volumes []DiskEncryptionVolume
|
||||
|
||||
mountsFile, err := os.Open("/proc/mounts")
|
||||
if err != nil {
|
||||
log.Debugf("open /proc/mounts: %v", err)
|
||||
return volumes
|
||||
}
|
||||
defer func() {
|
||||
if err := mountsFile.Close(); err != nil {
|
||||
log.Debugf("close /proc/mounts: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(mountsFile)
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
device, mountPoint := fields[0], fields[1]
|
||||
|
||||
encrypted := encryptedDevices[device]
|
||||
|
||||
if !encrypted && strings.HasPrefix(device, "/dev/mapper/") {
|
||||
for encDev := range encryptedDevices {
|
||||
if device == encDev {
|
||||
encrypted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
volumes = append(volumes, DiskEncryptionVolume{
|
||||
Path: mountPoint,
|
||||
Encrypted: encrypted,
|
||||
})
|
||||
}
|
||||
|
||||
return volumes
|
||||
}
|
||||
10
client/system/disk_encryption_stub.go
Normal file
10
client/system/disk_encryption_stub.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build android || ios || freebsd || js
|
||||
|
||||
package system
|
||||
|
||||
import "context"
|
||||
|
||||
// detectDiskEncryption is a stub for unsupported platforms.
|
||||
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||
return DiskEncryptionInfo{}
|
||||
}
|
||||
41
client/system/disk_encryption_windows.go
Normal file
41
client/system/disk_encryption_windows.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build windows
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
)
|
||||
|
||||
// Win32EncryptableVolume represents the WMI class for BitLocker status.
|
||||
type Win32EncryptableVolume struct {
|
||||
DriveLetter string
|
||||
ProtectionStatus uint32
|
||||
}
|
||||
|
||||
// detectDiskEncryption detects BitLocker encryption status on Windows via WMI.
|
||||
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
|
||||
info := DiskEncryptionInfo{}
|
||||
|
||||
var volumes []Win32EncryptableVolume
|
||||
query := "SELECT DriveLetter, ProtectionStatus FROM Win32_EncryptableVolume"
|
||||
|
||||
err := wmi.QueryNamespace(query, &volumes, `root\CIMV2\Security\MicrosoftVolumeEncryption`)
|
||||
if err != nil {
|
||||
log.Debugf("query BitLocker status: %v", err)
|
||||
return info
|
||||
}
|
||||
|
||||
for _, vol := range volumes {
|
||||
driveLetter := strings.TrimSuffix(vol.DriveLetter, "\\")
|
||||
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
|
||||
Path: driveLetter,
|
||||
Encrypted: vol.ProtectionStatus == 1,
|
||||
})
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
@@ -59,6 +59,7 @@ type Info struct {
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
Files []File // for posture checks
|
||||
DiskEncryption DiskEncryptionInfo
|
||||
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
@@ -44,6 +44,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemSerialNumber: serial(),
|
||||
SystemProductName: productModel(),
|
||||
SystemManufacturer: productManufacturer(),
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
return gio
|
||||
|
||||
@@ -62,6 +62,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
@@ -55,6 +55,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := extractOsName(ctx, "sysName")
|
||||
swVersion := extractOsVersion(ctx, "swVersion")
|
||||
|
||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
|
||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion, DiskEncryption: detectDiskEncryption(ctx)}
|
||||
gio.Hostname = extractDeviceName(ctx, "hostname")
|
||||
gio.NetbirdVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
@@ -15,7 +15,7 @@ func UpdateStaticInfoAsync() {
|
||||
}
|
||||
|
||||
// GetInfo retrieves system information for WASM environment
|
||||
func GetInfo(_ context.Context) *Info {
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: runtime.GOARCH,
|
||||
@@ -25,6 +25,7 @@ func GetInfo(_ context.Context) *Info {
|
||||
Hostname: "wasm-client",
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
collectBrowserInfo(info)
|
||||
|
||||
@@ -73,6 +73,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
return gio
|
||||
|
||||
@@ -35,6 +35,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
DiskEncryption: detectDiskEncryption(ctx),
|
||||
}
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
|
||||
@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
postUpStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
preDownStatusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
time.Now().Format(time.RFC3339), params.duration)
|
||||
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
statusOutput = overview.FullDetailSummary()
|
||||
}
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
|
||||
@@ -9,20 +9,29 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
pingBufferSize = 1500
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// validateSSHArgs validates SSH connection arguments
|
||||
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
|
||||
if len(args) < 2 {
|
||||
return "", 0, "", js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("host parameter must be a string")
|
||||
}
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return "", 0, "", js.ValueOf("port parameter must be a number")
|
||||
}
|
||||
|
||||
host = args[0].String()
|
||||
port = args[1].Int()
|
||||
username = "root"
|
||||
|
||||
if len(args) > 2 {
|
||||
if args[2].Type() == js.TypeString && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
} else if args[2].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("username parameter must be a string")
|
||||
}
|
||||
}
|
||||
|
||||
return host, port, username, js.Undefined()
|
||||
}
|
||||
|
||||
// createSSHMethod creates the SSH connection method
|
||||
func createSSHMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
host, port, username, validationErr := validateSSHArgs(args)
|
||||
if !validationErr.IsUndefined() {
|
||||
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
|
||||
return validationErr
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
username := "root"
|
||||
if len(args) > 2 && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(validationErr)
|
||||
})
|
||||
}
|
||||
|
||||
var jwtToken string
|
||||
@@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
func performPing(client *netbird.Client, hostname string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "ping", hostname)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close ping connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
icmpData := make([]byte, 8)
|
||||
icmpData[0] = icmpEchoRequest
|
||||
icmpData[1] = icmpCodeEcho
|
||||
|
||||
if _, err := conn.Write(icmpData); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, pingBufferSize)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||
return
|
||||
}
|
||||
latency := time.Since(start)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close TCP connection: %v", err)
|
||||
}
|
||||
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
// createPingMethod creates the ping method
|
||||
func createPingMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: hostname required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPing(client, hostname)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPingTCPMethod creates the pingtcp method
|
||||
func createPingTCPMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a number"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
port := args[1].Int()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPingTCP(client, hostname, port)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createProxyRequestMethod creates the proxyRequest method
|
||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
@@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
request := args[0]
|
||||
if request.Type() != js.TypeObject {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("request parameter must be an object"))
|
||||
})
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
response, err := http.ProxyRequest(client, request)
|
||||
@@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
if args[1].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
// getStatusOverview is a helper to get the status overview
|
||||
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||
fullStatus, err := client.Status()
|
||||
if err != nil {
|
||||
return nbstatus.OutputOverview{}, err
|
||||
}
|
||||
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
statusResp := &proto.StatusResponse{
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
FullStatus: pbFullStatus,
|
||||
}
|
||||
|
||||
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
|
||||
}
|
||||
|
||||
// createStatusMethod creates the status method that returns JSON
|
||||
func createStatusMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
jsonStr, err := overview.JSON()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusSummaryMethod creates the statusSummary method
|
||||
func createStatusSummaryMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
summary := overview.GeneralSummary(false, false, false, false)
|
||||
js.Global().Get("console").Call("log", summary)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusDetailMethod creates the statusDetail method
|
||||
func createStatusDetailMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
detail := overview.FullDetailSummary()
|
||||
js.Global().Get("console").Call("log", detail)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
|
||||
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
syncResp, err := client.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
options := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
AllowPartial: true,
|
||||
}
|
||||
jsonBytes, err := options.Marshal(syncResp)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
|
||||
func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: log level required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("log level parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
logLevel := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
if err := client.SetLogLevel(logLevel); err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
|
||||
return
|
||||
}
|
||||
log.Infof("Log level set to: %s", logLevel)
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
@@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["ping"] = createPingMethod(client)
|
||||
obj["pingtcp"] = createPingTCPMethod(client)
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
|
||||
8
go.mod
8
go.mod
@@ -78,8 +78,8 @@ require (
|
||||
github.com/pion/logging v0.2.4
|
||||
github.com/pion/randutil v0.1.0
|
||||
github.com/pion/stun/v2 v2.0.0
|
||||
github.com/pion/stun/v3 v3.0.0
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
@@ -241,7 +241,7 @@ require (
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.7 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||
github.com/pion/turn/v4 v4.1.1 // indirect
|
||||
@@ -263,7 +263,7 @@ require (
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
|
||||
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
||||
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
||||
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
|
||||
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
|
||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
|
||||
@@ -798,15 +798,15 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
|
||||
"insecureSkipEmailVerified": true,
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -14,6 +15,7 @@ type Repository interface {
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
|
||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
13
management/internals/modules/zones/interface.go
Normal file
13
management/internals/modules/zones/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
|
||||
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
|
||||
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
|
||||
}
|
||||
161
management/internals/modules/zones/manager/api.go
Normal file
161
management/internals/modules/zones/manager/api.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiZones := make([]*api.Zone, 0, len(allZones))
|
||||
for _, zone := range allZones {
|
||||
apiZones = append(apiZones, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
229
management/internals/modules/zones/manager/manager.go
Normal file
229
management/internals/modules/zones/manager/manager.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
dnsDomain string
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
|
||||
}
|
||||
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.CreateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to create zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
if zone.Domain != updatedZone.Domain {
|
||||
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
|
||||
}
|
||||
|
||||
zone.Name = updatedZone.Name
|
||||
zone.Enabled = updatedZone.Enabled
|
||||
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
|
||||
zone.DistributionGroups = updatedZone.DistributionGroups
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to update zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone dns records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZone(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
|
||||
})
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range eventsToStore {
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
|
||||
if m.dnsDomain != "" && m.dnsDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.DNSDomain != "" && settings.DNSDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
553
management/internals/modules/zones/manager/manager_test.go
Normal file
553
management/internals/modules/zones/manager/manager_test.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testZoneID = "test-zone-id"
|
||||
testGroupID = "test-group-id"
|
||||
testDNSDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
dnsDomain: testDNSDomain,
|
||||
}
|
||||
|
||||
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, zone1.ID, result[0].ID)
|
||||
assert.Equal(t, zone2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, zone.ID, result.ID)
|
||||
assert.Equal(t, zone.Name, result.Name)
|
||||
assert.Equal(t, zone.Domain, result.Domain)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, inputZone.Name, result.Name)
|
||||
assert.Equal(t, inputZone.Domain, result.Domain)
|
||||
assert.Equal(t, inputZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("invalid group", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{"invalid-group"},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate domain", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "duplicate.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, s.Type())
|
||||
})
|
||||
|
||||
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
account, err := testStore.GetAccount(ctx, testAccountID)
|
||||
require.NoError(t, err)
|
||||
account.Settings.DNSDomain = "peers.example.com"
|
||||
err = testStore.SaveAccount(ctx, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: "peers.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: testDNSDomain,
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingZone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedZone.Name, result.Name)
|
||||
assert.Equal(t, updatedZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("domain change not allowed", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Test Zone",
|
||||
Domain: "different.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain cannot be updated")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: testZoneID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: "non-existent-zone",
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success with records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCallCount++
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, storeEventCallCount)
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
|
||||
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, zoneRecords)
|
||||
})
|
||||
|
||||
t.Run("success without records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, zone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
13
management/internals/modules/zones/records/interface.go
Normal file
13
management/internals/modules/zones/records/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
|
||||
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
|
||||
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
|
||||
}
|
||||
191
management/internals/modules/zones/records/manager/api.go
Normal file
191
management/internals/modules/zones/records/manager/api.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
|
||||
for _, record := range allRecords {
|
||||
apiRecords = append(apiRecords, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
236
management/internals/modules/zones/records/manager/manager.go
Normal file
236
management/internals/modules/zones/records/manager/manager.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
|
||||
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
err = validateRecordConflicts(ctx, transaction, zone, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to create dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
var record *records.Record
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
|
||||
|
||||
record.Name = updatedRecord.Name
|
||||
record.Type = updatedRecord.Type
|
||||
record.Content = updatedRecord.Content
|
||||
record.TTL = updatedRecord.TTL
|
||||
|
||||
if hasChanges {
|
||||
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to update dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
var zone *zones.Zone
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRecordConflicts checks for duplicate records and CNAME conflicts
|
||||
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
|
||||
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
|
||||
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
|
||||
}
|
||||
|
||||
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing records: %w", err)
|
||||
}
|
||||
|
||||
for _, existing := range existingRecords {
|
||||
if existing.ID == record.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if existing.Type == record.Type && existing.Content == record.Content {
|
||||
return status.Errorf(status.AlreadyExists, "identical record already exists")
|
||||
}
|
||||
|
||||
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
|
||||
return status.Errorf(status.InvalidArgument,
|
||||
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testRecordID = "test-record-id"
|
||||
testGroupID = "test-group-id"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
}
|
||||
|
||||
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, record1.ID, result[0].ID)
|
||||
assert.Equal(t, record2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID, result.ID)
|
||||
assert.Equal(t, record.Name, result.Name)
|
||||
assert.Equal(t, record.Type, result.Type)
|
||||
assert.Equal(t, record.Content, result.Content)
|
||||
assert.Equal(t, record.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success - A record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, zone.ID, result.ZoneID)
|
||||
assert.Equal(t, inputRecord.Name, result.Name)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
assert.Equal(t, inputRecord.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("success - AAAA record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "ipv6.example.com",
|
||||
Type: records.RecordTypeAAAA,
|
||||
Content: "2001:db8::1",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("success - CNAME record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "www.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record name not in zone", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.different.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "does not belong to zone")
|
||||
})
|
||||
|
||||
t.Run("duplicate record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
|
||||
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100", // Changed IP
|
||||
TTL: 600, // Changed TTL
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingRecord.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedRecord.Content, result.Content)
|
||||
assert.Equal(t, updatedRecord.TTL, result.TTL)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: existingRecord.Name,
|
||||
Type: existingRecord.Type,
|
||||
Content: existingRecord.Content,
|
||||
TTL: 600, // Only TTL changed
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, 600, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: testRecordID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: "non-existent-record",
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("update creates duplicate", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: record2.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, record.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
129
management/internals/modules/zones/records/record.go
Normal file
129
management/internals/modules/zones/records/record.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type RecordType string
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = "A"
|
||||
RecordTypeAAAA RecordType = "AAAA"
|
||||
RecordTypeCNAME RecordType = "CNAME"
|
||||
)
|
||||
|
||||
type Record struct {
|
||||
AccountID string `gorm:"index"`
|
||||
ZoneID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
Name string
|
||||
Type RecordType
|
||||
Content string
|
||||
TTL int
|
||||
}
|
||||
|
||||
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
|
||||
return &Record{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
ZoneID: zoneID,
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
Content: content,
|
||||
TTL: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) ToAPIResponse() *api.DNSRecord {
|
||||
recordType := api.DNSRecordType(r.Type)
|
||||
return &api.DNSRecord{
|
||||
Id: r.ID,
|
||||
Name: r.Name,
|
||||
Type: recordType,
|
||||
Content: r.Content,
|
||||
Ttl: r.TTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
|
||||
r.Name = req.Name
|
||||
r.Type = RecordType(req.Type)
|
||||
r.Content = req.Content
|
||||
r.TTL = req.Ttl
|
||||
}
|
||||
|
||||
func (r *Record) Validate() error {
|
||||
if r.Name == "" {
|
||||
return errors.New("record name is required")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(r.Name) {
|
||||
return errors.New("invalid record name format")
|
||||
}
|
||||
|
||||
if r.Type == "" {
|
||||
return errors.New("record type is required")
|
||||
}
|
||||
|
||||
switch r.Type {
|
||||
case RecordTypeA:
|
||||
if err := validateIPv4(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
if err := validateIPv6(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeCNAME:
|
||||
if !util.IsValidDomain(r.Content) {
|
||||
return errors.New("invalid CNAME record format")
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||
}
|
||||
|
||||
if r.TTL < 0 {
|
||||
return errors.New("TTL cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
|
||||
return map[string]any{
|
||||
"name": r.Name,
|
||||
"type": string(r.Type),
|
||||
"content": r.Content,
|
||||
"ttl": r.TTL,
|
||||
"zone_id": zoneID,
|
||||
"zone_name": zoneName,
|
||||
}
|
||||
}
|
||||
|
||||
func validateIPv4(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("A record is required") //nolint:staticcheck
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIPv6(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("AAAA record is required")
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return errors.New("AAAA record must be a valid IPv6 address")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
89
management/internals/modules/zones/zone.go
Normal file
89
management/internals/modules/zones/zone.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type Zone struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string
|
||||
Enabled bool
|
||||
EnableSearchDomain bool
|
||||
DistributionGroups []string `gorm:"serializer:json"`
|
||||
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
|
||||
}
|
||||
|
||||
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
|
||||
return &Zone{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
Enabled: enabled,
|
||||
EnableSearchDomain: enableSearchDomain,
|
||||
DistributionGroups: distributionGroups,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) ToAPIResponse() *api.Zone {
|
||||
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
|
||||
for _, record := range z.Records {
|
||||
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
|
||||
apiRecords = append(apiRecords, *apiRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Zone{
|
||||
DistributionGroups: z.DistributionGroups,
|
||||
Domain: z.Domain,
|
||||
EnableSearchDomain: z.EnableSearchDomain,
|
||||
Enabled: z.Enabled,
|
||||
Id: z.ID,
|
||||
Name: z.Name,
|
||||
Records: apiRecords,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
|
||||
z.Name = req.Name
|
||||
z.Domain = req.Domain
|
||||
z.EnableSearchDomain = req.EnableSearchDomain
|
||||
z.DistributionGroups = req.DistributionGroups
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
z.Enabled = enabled
|
||||
}
|
||||
|
||||
func (z *Zone) Validate() error {
|
||||
if z.Name == "" {
|
||||
return errors.New("zone name is required")
|
||||
}
|
||||
if len(z.Name) > 255 {
|
||||
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(z.Domain) {
|
||||
return errors.New("invalid zone domain format")
|
||||
}
|
||||
|
||||
if len(z.DistributionGroups) == 0 {
|
||||
return errors.New("at least one distribution group is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Zone) EventMeta() map[string]any {
|
||||
return map[string]any{"name": z.Name, "domain": z.Domain}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RecordsManager() records.Manager {
|
||||
return Create(s, func() records.Manager {
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -376,6 +376,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
@@ -433,9 +434,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
|
||||
if config.CLIAuthAudience != "" {
|
||||
audience = config.CLIAuthAudience
|
||||
}
|
||||
|
||||
audiences := []string{config.AuthAudience}
|
||||
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
|
||||
audiences = append(audiences, config.CLIAuthAudience)
|
||||
}
|
||||
|
||||
return &proto.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audiences: audiences,
|
||||
KeysLocation: keysLocation,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -148,3 +151,51 @@ func generateTestData(size int) nbdns.Config {
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authAudience string
|
||||
cliAuthAudience string
|
||||
expectedAudiences []string
|
||||
expectedAudience string
|
||||
}{
|
||||
{
|
||||
name: "only_auth_audience",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "",
|
||||
expectedAudiences: []string{"dashboard-aud"},
|
||||
expectedAudience: "dashboard-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_different",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "cli-aud",
|
||||
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
|
||||
expectedAudience: "cli-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_same",
|
||||
authAudience: "same-aud",
|
||||
cliAuthAudience: "same-aud",
|
||||
expectedAudiences: []string{"same-aud"},
|
||||
expectedAudience: "same-aud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &nbconfig.HttpServerConfig{
|
||||
AuthIssuer: "https://issuer.example.com",
|
||||
AuthAudience: tc.authAudience,
|
||||
CLIAuthAudience: tc.cliAuthAudience,
|
||||
}
|
||||
|
||||
result := buildJWTConfig(config, nil)
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
||||
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,6 +470,16 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
})
|
||||
}
|
||||
|
||||
diskEncryptionVolumes := make([]nbpeer.DiskEncryptionVolume, 0)
|
||||
if meta.GetDiskEncryption() != nil {
|
||||
for _, vol := range meta.GetDiskEncryption().GetVolumes() {
|
||||
diskEncryptionVolumes = append(diskEncryptionVolumes, nbpeer.DiskEncryptionVolume{
|
||||
Path: vol.GetPath(),
|
||||
Encrypted: vol.GetEncrypted(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nbpeer.PeerSystemMeta{
|
||||
Hostname: meta.GetHostname(),
|
||||
GoOS: meta.GetGoOS(),
|
||||
@@ -501,6 +511,9 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
},
|
||||
Files: files,
|
||||
DiskEncryption: nbpeer.DiskEncryptionInfo{
|
||||
Volumes: diskEncryptionVolumes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return newSettings, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
|
||||
}
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.Store.CreateZone(ctx, &zones.Zone{
|
||||
ID: "test-zone-id",
|
||||
AccountID: accountID,
|
||||
Name: "Test Zone",
|
||||
Domain: "custom.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{},
|
||||
})
|
||||
require.NoError(t, err, "unable to create custom DNS zone")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||
DNSDomain: "custom.example.com",
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
|
||||
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
|
||||
}
|
||||
|
||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
|
||||
@@ -187,6 +187,14 @@ const (
|
||||
IdentityProviderUpdated Activity = 94
|
||||
IdentityProviderDeleted Activity = 95
|
||||
|
||||
DNSZoneCreated Activity = 96
|
||||
DNSZoneUpdated Activity = 97
|
||||
DNSZoneDeleted Activity = 98
|
||||
|
||||
DNSRecordCreated Activity = 99
|
||||
DNSRecordUpdated Activity = 100
|
||||
DNSRecordDeleted Activity = 101
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
|
||||
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
||||
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
||||
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
||||
|
||||
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
|
||||
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
|
||||
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
|
||||
|
||||
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
|
||||
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
|
||||
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
25
management/server/cache/idp.go
vendored
25
management/server/cache/idp.go
vendored
@@ -26,6 +26,8 @@ type UserDataCache interface {
|
||||
Get(ctx context.Context, key string) (*idp.UserData, error)
|
||||
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
|
||||
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
|
||||
}
|
||||
|
||||
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
||||
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
|
||||
return u.cache.Delete(ctx, key)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
|
||||
var users []*idp.UserData
|
||||
v, err := u.cache.Get(ctx, key, &users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case []*idp.UserData:
|
||||
return v, nil
|
||||
case *[]*idp.UserData:
|
||||
return *v, nil
|
||||
case []byte:
|
||||
return unmarshalUserData(v)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
|
||||
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
|
||||
}
|
||||
|
||||
// NewUserDataCache creates a new UserDataCacheImpl object.
|
||||
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
||||
simpleCache := cache.New[any](store)
|
||||
|
||||
@@ -15,7 +15,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
@@ -56,7 +59,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||
return strings.Contains(ua, "terraform")
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
|
||||
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test various Terraform user agent formats
|
||||
terraformUserAgents := []string{
|
||||
"Terraform/1.5.0",
|
||||
"terraform/1.0.0",
|
||||
"Terraform-Provider/2.0.0",
|
||||
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||
}
|
||||
|
||||
for _, userAgent := range terraformUserAgents {
|
||||
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -136,9 +136,10 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.AuthIssuer == "" {
|
||||
|
||||
@@ -48,13 +48,12 @@ type AuthentikCredentials struct {
|
||||
}
|
||||
|
||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||
func NewAuthentikManager(config AuthentikClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
|
||||
@@ -58,9 +58,10 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
@@ -49,9 +48,10 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.CustomerID == "" {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
@@ -46,9 +45,10 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.APIToken == "" {
|
||||
|
||||
@@ -63,9 +63,10 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
@@ -45,7 +44,7 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -88,9 +87,10 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,3 +71,24 @@ func baseURL(rawURL string) string {
|
||||
|
||||
return parsedURL.Scheme + "://" + parsedURL.Host
|
||||
}
|
||||
|
||||
const (
|
||||
// Provides the env variable name for use with idpTimeout function
|
||||
idpTimeoutEnv = "NB_IDP_TIMEOUT"
|
||||
// Sets the defaultTimeout to 10s.
|
||||
defaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// idpTimeout returns a timeout value for the IDP
|
||||
func idpTimeout() time.Duration {
|
||||
timeoutStr, ok := os.LookupEnv(idpTimeoutEnv)
|
||||
if !ok || timeoutStr == "" {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
timeout, err := time.ParseDuration(timeoutStr)
|
||||
if err != nil {
|
||||
return defaultTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
@@ -164,9 +164,10 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: idpTimeout(),
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
hasPAT := config.PAT != ""
|
||||
|
||||
@@ -95,6 +95,27 @@ type File struct {
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// DiskEncryptionVolume represents encryption status of a volume.
|
||||
type DiskEncryptionVolume struct {
|
||||
Path string
|
||||
Encrypted bool
|
||||
}
|
||||
|
||||
// DiskEncryptionInfo holds encryption info for all volumes.
|
||||
type DiskEncryptionInfo struct {
|
||||
Volumes []DiskEncryptionVolume `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// IsEncrypted returns true if the volume at path is encrypted.
|
||||
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
|
||||
for _, v := range d.Volumes {
|
||||
if v.Path == path {
|
||||
return v.Encrypted
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Flags defines a set of options to control feature behavior
|
||||
type Flags struct {
|
||||
RosenpassEnabled bool
|
||||
@@ -130,6 +151,7 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
Environment Environment `gorm:"serializer:json"`
|
||||
Flags Flags `gorm:"serializer:json"`
|
||||
Files []File `gorm:"serializer:json"`
|
||||
DiskEncryption DiskEncryptionInfo `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
@@ -159,6 +181,19 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.DiskEncryption.Volumes, func(i, j int) bool {
|
||||
return p.DiskEncryption.Volumes[i].Path < p.DiskEncryption.Volumes[j].Path
|
||||
})
|
||||
sort.Slice(other.DiskEncryption.Volumes, func(i, j int) bool {
|
||||
return other.DiskEncryption.Volumes[i].Path < other.DiskEncryption.Volumes[j].Path
|
||||
})
|
||||
equalDiskEncryption := slices.EqualFunc(p.DiskEncryption.Volumes, other.DiskEncryption.Volumes, func(vol DiskEncryptionVolume, oVol DiskEncryptionVolume) bool {
|
||||
return vol.Path == oVol.Path && vol.Encrypted == oVol.Encrypted
|
||||
})
|
||||
if !equalDiskEncryption {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
p.GoOS == other.GoOS &&
|
||||
p.Kernel == other.Kernel &&
|
||||
|
||||
@@ -18,6 +18,7 @@ const (
|
||||
GeoLocationCheckName = "GeoLocationCheck"
|
||||
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
|
||||
ProcessCheckName = "ProcessCheck"
|
||||
DiskEncryptionCheckName = "DiskEncryptionCheck"
|
||||
|
||||
CheckActionAllow string = "allow"
|
||||
CheckActionDeny string = "deny"
|
||||
@@ -58,6 +59,7 @@ type ChecksDefinition struct {
|
||||
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
|
||||
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
|
||||
ProcessCheck *ProcessCheck `json:",omitempty"`
|
||||
DiskEncryptionCheck *DiskEncryptionCheck `json:",omitempty"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of a checks definition.
|
||||
@@ -110,6 +112,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
|
||||
}
|
||||
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
|
||||
}
|
||||
if cd.DiskEncryptionCheck != nil {
|
||||
cdCopy.DiskEncryptionCheck = &DiskEncryptionCheck{
|
||||
LinuxPath: cd.DiskEncryptionCheck.LinuxPath,
|
||||
DarwinPath: cd.DiskEncryptionCheck.DarwinPath,
|
||||
WindowsPath: cd.DiskEncryptionCheck.WindowsPath,
|
||||
}
|
||||
}
|
||||
return cdCopy
|
||||
}
|
||||
|
||||
@@ -153,6 +162,9 @@ func (pc *Checks) GetChecks() []Check {
|
||||
if pc.Checks.ProcessCheck != nil {
|
||||
checks = append(checks, pc.Checks.ProcessCheck)
|
||||
}
|
||||
if pc.Checks.DiskEncryptionCheck != nil {
|
||||
checks = append(checks, pc.Checks.DiskEncryptionCheck)
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -208,6 +220,10 @@ func buildPostureCheck(postureChecksID string, name string, description string,
|
||||
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
|
||||
}
|
||||
|
||||
if diskEncryptionCheck := checks.DiskEncryptionCheck; diskEncryptionCheck != nil {
|
||||
postureChecks.Checks.DiskEncryptionCheck = toDiskEncryptionCheck(diskEncryptionCheck)
|
||||
}
|
||||
|
||||
return &postureChecks, nil
|
||||
}
|
||||
|
||||
@@ -242,6 +258,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
|
||||
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
|
||||
}
|
||||
|
||||
if pc.Checks.DiskEncryptionCheck != nil {
|
||||
checks.DiskEncryptionCheck = toDiskEncryptionCheckResponse(pc.Checks.DiskEncryptionCheck)
|
||||
}
|
||||
|
||||
return &api.PostureCheck{
|
||||
Id: pc.ID,
|
||||
Name: pc.Name,
|
||||
@@ -386,3 +406,25 @@ func toProcessCheck(check *api.ProcessCheck) *ProcessCheck {
|
||||
Processes: processes,
|
||||
}
|
||||
}
|
||||
|
||||
func toDiskEncryptionCheck(check *api.DiskEncryptionCheck) *DiskEncryptionCheck {
|
||||
d := &DiskEncryptionCheck{}
|
||||
if check.LinuxPath != nil {
|
||||
d.LinuxPath = *check.LinuxPath
|
||||
}
|
||||
if check.DarwinPath != nil {
|
||||
d.DarwinPath = *check.DarwinPath
|
||||
}
|
||||
if check.WindowsPath != nil {
|
||||
d.WindowsPath = *check.WindowsPath
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func toDiskEncryptionCheckResponse(check *DiskEncryptionCheck) *api.DiskEncryptionCheck {
|
||||
return &api.DiskEncryptionCheck{
|
||||
LinuxPath: &check.LinuxPath,
|
||||
DarwinPath: &check.DarwinPath,
|
||||
WindowsPath: &check.WindowsPath,
|
||||
}
|
||||
}
|
||||
|
||||
52
management/server/posture/disk_encryption.go
Normal file
52
management/server/posture/disk_encryption.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// DiskEncryptionCheck verifies that specified volumes are encrypted.
|
||||
type DiskEncryptionCheck struct {
|
||||
LinuxPath string
|
||||
DarwinPath string
|
||||
WindowsPath string
|
||||
}
|
||||
|
||||
var _ Check = (*DiskEncryptionCheck)(nil)
|
||||
|
||||
// Name returns the name of the check.
|
||||
func (d *DiskEncryptionCheck) Name() string {
|
||||
return DiskEncryptionCheckName
|
||||
}
|
||||
|
||||
// Check performs the disk encryption verification for the given peer.
|
||||
func (d *DiskEncryptionCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
var pathToCheck string
|
||||
|
||||
switch peer.Meta.GoOS {
|
||||
case "linux":
|
||||
pathToCheck = d.LinuxPath
|
||||
case "darwin":
|
||||
pathToCheck = d.DarwinPath
|
||||
case "windows":
|
||||
pathToCheck = d.WindowsPath
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if pathToCheck == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return peer.Meta.DiskEncryption.IsEncrypted(pathToCheck), nil
|
||||
}
|
||||
|
||||
// Validate checks the configuration of the disk encryption check.
|
||||
func (d *DiskEncryptionCheck) Validate() error {
|
||||
if d.LinuxPath == "" && d.DarwinPath == "" && d.WindowsPath == "" {
|
||||
return fmt.Errorf("%s at least one path must be configured", d.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
306
management/server/posture/disk_encryption_test.go
Normal file
306
management/server/posture/disk_encryption_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestDiskEncryptionCheck_Check(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input peer.Peer
|
||||
check DiskEncryptionCheck
|
||||
wantErr bool
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "linux with encrypted root",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: true},
|
||||
{Path: "/home", Encrypted: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "linux with unencrypted root",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: false},
|
||||
{Path: "/home", Encrypted: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "linux with no volume info",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "darwin with encrypted root",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
DarwinPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "darwin with unencrypted root",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
DarwinPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "windows with encrypted C drive",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "C:", Encrypted: true},
|
||||
{Path: "D:", Encrypted: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "windows with unencrypted C drive",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "C:", Encrypted: false},
|
||||
{Path: "D:", Encrypted: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported ios operating system",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "ios",
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
DarwinPath: "/",
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported android operating system",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "android",
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
DarwinPath: "/",
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "linux peer with no linux path configured passes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
DarwinPath: "/",
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "darwin peer with no darwin path configured passes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "/", Encrypted: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "windows peer with no windows path configured passes",
|
||||
input: peer.Peer{
|
||||
Meta: peer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
DiskEncryption: peer.DiskEncryptionInfo{
|
||||
Volumes: []peer.DiskEncryptionVolume{
|
||||
{Path: "C:", Encrypted: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
DarwinPath: "/",
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid, err := tt.check.Check(context.Background(), tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.isValid, isValid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiskEncryptionCheck_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
check DiskEncryptionCheck
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "valid linux, darwin and windows paths",
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
DarwinPath: "/",
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid linux path only",
|
||||
check: DiskEncryptionCheck{
|
||||
LinuxPath: "/",
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid darwin path only",
|
||||
check: DiskEncryptionCheck{
|
||||
DarwinPath: "/",
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid windows path only",
|
||||
check: DiskEncryptionCheck{
|
||||
WindowsPath: "C:",
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid empty paths",
|
||||
check: DiskEncryptionCheck{},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.check.Validate()
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiskEncryptionCheck_Name(t *testing.T) {
|
||||
check := DiskEncryptionCheck{}
|
||||
assert.Equal(t, DiskEncryptionCheckName, check.Name())
|
||||
}
|
||||
@@ -27,6 +27,8 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -123,6 +125,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&zones.Zone{}, &records.Record{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -4179,3 +4182,184 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
|
||||
result := s.db.Create(zone)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create zone to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
|
||||
result := s.db.Select("*").Save(zone)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update zone to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
|
||||
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete zone from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewZoneNotFoundError(zoneID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewZoneNotFoundError(zoneID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone from store")
|
||||
}
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
|
||||
var zone *zones.Zone
|
||||
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewZoneNotFoundError(domain)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
|
||||
}
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var zones []*zones.Zone
|
||||
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zones from store")
|
||||
}
|
||||
|
||||
return zones, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||
result := s.db.Create(record)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create dns record to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||
result := s.db.Select("*").Save(record)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update dns record to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
|
||||
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete dns record from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewDNSRecordNotFoundError(recordID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewDNSRecordNotFoundError(recordID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var recordsList []*records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
|
||||
}
|
||||
|
||||
return recordsList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var recordsList []*records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
|
||||
}
|
||||
|
||||
return recordsList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
|
||||
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
assert.Equal(t, zone.Domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.Enabled, savedZone.Enabled)
|
||||
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
|
||||
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty zone ID",
|
||||
accountID: accountID,
|
||||
zoneID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.zoneID, savedZone.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountZones(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
|
||||
err = store.CreateZone(context.Background(), zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allZones)
|
||||
assert.GreaterOrEqual(t, len(allZones), 2)
|
||||
|
||||
zoneIDs := make(map[string]bool)
|
||||
for _, z := range allZones {
|
||||
zoneIDs[z.ID] = true
|
||||
}
|
||||
assert.True(t, zoneIDs[zone1.ID])
|
||||
assert.True(t, zoneIDs[zone2.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByDomain(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
expectError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone by domain",
|
||||
accountID: accountID,
|
||||
domain: "example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone domain",
|
||||
accountID: accountID,
|
||||
domain: "non-existing.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty domain",
|
||||
accountID: accountID,
|
||||
domain: "",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with different account ID",
|
||||
accountID: otherAccountID,
|
||||
domain: "example.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.errorType, sErr.Type())
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone.Name = "Updated Zone"
|
||||
zone.Domain = "updated.com"
|
||||
zone.Enabled = false
|
||||
zone.EnableSearchDomain = true
|
||||
zone.DistributionGroups = []string{"group2", "group3"}
|
||||
|
||||
err = store.UpdateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedZone)
|
||||
assert.Equal(t, "Updated Zone", updatedZone.Name)
|
||||
assert.Equal(t, "updated.com", updatedZone.Domain)
|
||||
assert.False(t, updatedZone.Enabled)
|
||||
assert.True(t, updatedZone.EnableSearchDomain)
|
||||
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteZone(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedZone)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, record.ID, savedRecord.ID)
|
||||
assert.Equal(t, record.Name, savedRecord.Name)
|
||||
assert.Equal(t, record.Type, savedRecord.Type)
|
||||
assert.Equal(t, record.Content, savedRecord.Content)
|
||||
assert.Equal(t, record.TTL, savedRecord.TTL)
|
||||
assert.Equal(t, zone.ID, savedRecord.ZoneID)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
recordID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: record.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty record ID",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedRecord)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, tt.recordID, savedRecord.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordAAAA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allRecords)
|
||||
assert.Equal(t, 3, len(allRecords))
|
||||
|
||||
recordIDs := make(map[string]bool)
|
||||
for _, r := range allRecords {
|
||||
recordIDs[r.ID] = true
|
||||
}
|
||||
assert.True(t, recordIDs[recordA.ID])
|
||||
assert.True(t, recordIDs[recordAAAA.ID])
|
||||
assert.True(t, recordIDs[recordCNAME.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record3)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recordsByName)
|
||||
assert.Equal(t, 2, len(recordsByName))
|
||||
|
||||
for _, r := range recordsByName {
|
||||
assert.Equal(t, "www.example.com", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
record.Name = "api.example.com"
|
||||
record.Content = "192.168.1.100"
|
||||
record.TTL = 600
|
||||
|
||||
err = store.UpdateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedRecord)
|
||||
assert.Equal(t, "api.example.com", updatedRecord.Name)
|
||||
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
|
||||
assert.Equal(t, 600, updatedRecord.TTL)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedRecord)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(allRecords))
|
||||
|
||||
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -209,6 +211,21 @@ type Store interface {
|
||||
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
||||
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
||||
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||
|
||||
CreateZone(ctx context.Context, zone *zones.Zone) error
|
||||
UpdateZone(ctx context.Context, zone *zones.Zone) error
|
||||
DeleteZone(ctx context.Context, accountID, zoneID string) error
|
||||
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
|
||||
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
|
||||
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
|
||||
|
||||
CreateDNSRecord(ctx context.Context, record *records.Record) error
|
||||
UpdateDNSRecord(ctx context.Context, record *records.Record) error
|
||||
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
|
||||
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
|
||||
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
|
||||
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
|
||||
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -150,17 +152,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
@@ -274,6 +275,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
@@ -294,6 +296,8 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -307,7 +311,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
@@ -323,6 +327,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
@@ -330,6 +335,10 @@ func (a *Account) GetPeerNetworkMap(
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
@@ -1881,3 +1890,66 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
// filterPeerAppliedZones filters account zones based on the peer's group membership
|
||||
func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone {
|
||||
var customZones []nbdns.CustomZone
|
||||
|
||||
if len(peerGroups) == 0 {
|
||||
return customZones
|
||||
}
|
||||
|
||||
for _, zone := range accountZones {
|
||||
if !zone.Enabled || len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hasAccess := false
|
||||
for _, distGroupID := range zone.DistributionGroups {
|
||||
if _, found := peerGroups[distGroupID]; found {
|
||||
hasAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAccess {
|
||||
continue
|
||||
}
|
||||
|
||||
simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records))
|
||||
for _, record := range zone.Records {
|
||||
var recordType int
|
||||
rData := record.Content
|
||||
|
||||
switch record.Type {
|
||||
case records.RecordTypeA:
|
||||
recordType = int(dns.TypeA)
|
||||
case records.RecordTypeAAAA:
|
||||
recordType = int(dns.TypeAAAA)
|
||||
case records.RecordTypeCNAME:
|
||||
recordType = int(dns.TypeCNAME)
|
||||
rData = dns.Fqdn(record.Content)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
simpleRecords = append(simpleRecords, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(record.Name),
|
||||
Type: recordType,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: record.TTL,
|
||||
RData: rData,
|
||||
})
|
||||
}
|
||||
|
||||
customZones = append(customZones, nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(zone.Domain),
|
||||
Records: simpleRecords,
|
||||
SearchDomainDisabled: !zone.EnableSearchDomain,
|
||||
NonAuthoritative: true,
|
||||
})
|
||||
}
|
||||
|
||||
return customZones
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -1425,3 +1427,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_filterPeerAppliedZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountZones []*zones.Zone
|
||||
peerGroups LookupMap
|
||||
expected []nbdns.CustomZone
|
||||
}{
|
||||
{
|
||||
name: "empty peer groups returns empty custom zones",
|
||||
accountZones: []*zones.Zone{},
|
||||
peerGroups: LookupMap{},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "peer has access to zone with A record",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer has access to zone with search domain enabled",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "internal.local",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "api.internal.local",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "10.0.0.1",
|
||||
TTL: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "internal.local.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "api.internal.local.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "10.0.0.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer has no access to zone",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "private.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "secret.private.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "disabled zone is filtered out",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "disabled.com",
|
||||
Enabled: false,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.disabled.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "zone with no records is filtered out",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "empty.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "peer has access via multiple groups",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "multi.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1", "group2", "group3"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.multi.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group2": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "multi.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.multi.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple zones with mixed access",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "allowed.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.allowed.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "denied.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.denied.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "allowed.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.allowed.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zone with multiple record types",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "mixed.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.mixed.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "ipv6.mixed.com",
|
||||
Type: records.RecordTypeAAAA,
|
||||
Content: "2001:db8::1",
|
||||
TTL: 600,
|
||||
},
|
||||
{
|
||||
ID: "record3",
|
||||
Name: "alias.mixed.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "www.mixed.com",
|
||||
TTL: 900,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "mixed.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.mixed.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
Name: "ipv6.mixed.com.",
|
||||
Type: int(dns.TypeAAAA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
Name: "alias.mixed.com.",
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 900,
|
||||
RData: "www.mixed.com.",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple zones both accessible",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "first.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.first.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "second.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.second.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "first.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.first.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: false,
|
||||
},
|
||||
{
|
||||
Domain: "second.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.second.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zone with multiple records of same type",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "multi-a.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.multi-a.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.multi-a.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "multi-a.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.multi-a.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
Name: "www.multi-a.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer in multiple groups accessing different zones",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "zone1.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.zone1.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "zone2.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.zone2.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "zone1.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.zone1.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
{
|
||||
Domain: "zone2.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.zone2.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups)
|
||||
require.Equal(t, len(tt.expected), len(result), "number of custom zones should match")
|
||||
|
||||
for i, expectedZone := range tt.expected {
|
||||
assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match")
|
||||
assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match")
|
||||
assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match")
|
||||
|
||||
for j, expectedRecord := range expectedZone.Records {
|
||||
assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match")
|
||||
assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match")
|
||||
assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match")
|
||||
assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match")
|
||||
assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
|
||||
@@ -70,13 +70,13 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -124,7 +124,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
for range b.N {
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
for _, peerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -185,7 +185,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
require.NoError(t, err, "error adding peer to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -250,7 +250,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -325,7 +325,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
require.NoError(t, err, "error adding router to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -412,7 +412,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -466,7 +466,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
require.NoError(t, err, "error deleting peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -545,7 +545,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
||||
require.NoError(t, err, "error deleting routing peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -607,7 +607,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -944,7 +944,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -1033,7 +1034,7 @@ func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
@@ -1057,7 +1058,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
||||
nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
@@ -1074,8 +1075,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{},
|
||||
) *NetworkMap {
|
||||
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -1125,13 +1126,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
}
|
||||
|
||||
finalDNSConfig := *dnsConfig
|
||||
if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
|
||||
if finalDNSConfig.ServiceEnable {
|
||||
var zones []nbdns.CustomZone
|
||||
records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
|
||||
|
||||
peerGroupsSlice := b.cache.peerToGroups[peer.ID]
|
||||
peerGroups := make(LookupMap, len(peerGroupsSlice))
|
||||
for _, groupID := range peerGroupsSlice {
|
||||
peerGroups[groupID] = struct{}{}
|
||||
}
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: customZone.Domain,
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
finalDNSConfig.CustomZones = zones
|
||||
}
|
||||
|
||||
|
||||
@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
accountUsers := []*types.User{}
|
||||
switch {
|
||||
case allowed:
|
||||
start := time.Now()
|
||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
|
||||
case user != nil && user.AccountID == accountID:
|
||||
accountUsers = append(accountUsers, user)
|
||||
default:
|
||||
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
filtered := make(map[string]*idp.UserData, len(accountUsers))
|
||||
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
|
||||
start := time.Now()
|
||||
|
||||
integrationKeys := make(map[string]struct{})
|
||||
for _, user := range accountUsers {
|
||||
if user.Issued == types.UserIssuedIntegration {
|
||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
||||
users[user.Id] = true
|
||||
continue
|
||||
}
|
||||
usersFromIntegration = append(usersFromIntegration, info)
|
||||
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
|
||||
continue
|
||||
}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||
}
|
||||
}
|
||||
|
||||
for key := range integrationKeys {
|
||||
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
|
||||
continue
|
||||
}
|
||||
for _, ud := range usersData {
|
||||
filtered[ud.ID] = ud
|
||||
}
|
||||
}
|
||||
|
||||
for _, ud := range filtered {
|
||||
usersFromIntegration = append(usersFromIntegration, ud)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
|
||||
start = time.Now()
|
||||
queriedUsers, err = am.lookupCache(ctx, users, accountID)
|
||||
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cacheManager := am.GetExternalCacheManager()
|
||||
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
|
||||
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
|
||||
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
|
||||
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package util
|
||||
|
||||
import "regexp"
|
||||
|
||||
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
|
||||
|
||||
// Difference returns the elements in `a` that aren't in `b`.
|
||||
func Difference(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
@@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsValidDomain(domain string) bool {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
return domainRegex.MatchString(domain)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/stun"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -43,6 +45,10 @@ type Config struct {
|
||||
LogLevel string
|
||||
LogFile string
|
||||
HealthcheckListenAddress string
|
||||
// STUN server configuration
|
||||
EnableSTUN bool
|
||||
STUNPorts []int
|
||||
STUNLogLevel string
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
|
||||
if c.AuthSecret == "" {
|
||||
return fmt.Errorf("auth secret is required")
|
||||
}
|
||||
|
||||
// Validate STUN configuration
|
||||
if c.EnableSTUN {
|
||||
if len(c.STUNPorts) == 0 {
|
||||
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
|
||||
}
|
||||
|
||||
seen := make(map[int]bool)
|
||||
for _, port := range c.STUNPorts {
|
||||
if port <= 0 || port > 65535 {
|
||||
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
|
||||
}
|
||||
if seen[port] {
|
||||
return fmt.Errorf("duplicate STUN port %d", port)
|
||||
}
|
||||
seen[port] = true
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -91,6 +116,9 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
||||
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
|
||||
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
|
||||
|
||||
setFlagsFromEnvVars(rootCmd)
|
||||
}
|
||||
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to initialize log: %s", err)
|
||||
}
|
||||
|
||||
// Resource creation phase (fail fast before starting any goroutines)
|
||||
|
||||
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
||||
if err != nil {
|
||||
log.Debugf("setup metrics: %v", err)
|
||||
return fmt.Errorf("setup metrics: %v", err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
srvListenerCfg := server.ListenerConfig{
|
||||
Address: cobraConfig.ListenAddress,
|
||||
}
|
||||
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
srvListenerCfg.TLSConfig = tlsConfig
|
||||
|
||||
// Create STUN listeners early to fail fast
|
||||
stunListeners, err := createSTUNListeners()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||
|
||||
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
TLSSupport: tlsSupport,
|
||||
}
|
||||
|
||||
srv, err := server.NewServer(cfg)
|
||||
srv, err := createRelayServer(cfg)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create relay server: %v", err)
|
||||
return fmt.Errorf("failed to create relay server: %v", err)
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
return err
|
||||
}
|
||||
|
||||
hCfg := healthcheck.Config{
|
||||
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
||||
ServiceChecker: srv,
|
||||
}
|
||||
httpHealthcheck, err := createHealthCheck(hCfg)
|
||||
if err != nil {
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
return err
|
||||
}
|
||||
|
||||
var stunServer *stun.Server
|
||||
if len(stunListeners) > 0 {
|
||||
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
|
||||
}
|
||||
|
||||
// Start all servers (only after all resources are successfully created)
|
||||
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
|
||||
|
||||
waitForExitSignal()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceURL := srv.InstanceURL()
|
||||
log.Infof("server will be available on: %s", instanceURL.String())
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := srv.Listen(srvListenerCfg); err != nil {
|
||||
log.Fatalf("failed to bind server: %s", err)
|
||||
log.Fatalf("failed to bind relay server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
hCfg := healthcheck.Config{
|
||||
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
||||
ServiceChecker: srv,
|
||||
}
|
||||
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create healthcheck server: %v", err)
|
||||
return fmt.Errorf("failed to create healthcheck server: %v", err)
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Failed to start healthcheck server: %v", err)
|
||||
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// it will block until exit signal
|
||||
waitForExitSignal()
|
||||
if stunServer != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := stunServer.Listen(); err != nil {
|
||||
if errors.Is(err, stun.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
log.Errorf("STUN server error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
|
||||
var errs error
|
||||
|
||||
var shutDownErrors error
|
||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
|
||||
}
|
||||
|
||||
if stunServer != nil {
|
||||
if err := stunServer.Shutdown(); err != nil {
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
|
||||
}
|
||||
|
||||
log.Infof("shutting down metrics server")
|
||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
||||
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return shutDownErrors
|
||||
return errs
|
||||
}
|
||||
|
||||
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
|
||||
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create healthcheck server: %v", err)
|
||||
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
|
||||
}
|
||||
return httpHealthcheck, nil
|
||||
}
|
||||
|
||||
func createRelayServer(cfg server.Config) (*server.Server, error) {
|
||||
srv, err := server.NewServer(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create relay server: %v", err)
|
||||
}
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
|
||||
for _, l := range stunListeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func createSTUNListeners() ([]*net.UDPConn, error) {
|
||||
var stunListeners []*net.UDPConn
|
||||
if cobraConfig.EnableSTUN {
|
||||
for _, port := range cobraConfig.STUNPorts {
|
||||
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
// Close already opened listeners on failure
|
||||
cleanupSTUNListeners(stunListeners)
|
||||
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
|
||||
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
|
||||
}
|
||||
stunListeners = append(stunListeners, listener)
|
||||
}
|
||||
}
|
||||
return stunListeners, nil
|
||||
}
|
||||
|
||||
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
||||
|
||||
@@ -8,6 +8,10 @@ pid="$(pgrep -x -f /usr/bin/netbird-ui || true)"
|
||||
if [ -n "${pid}" ]
|
||||
then
|
||||
uid="$(cat /proc/"${pid}"/loginuid)"
|
||||
# loginuid can be 4294967295 (-1) if not set, fall back to process uid
|
||||
if [ "${uid}" = "4294967295" ] || [ "${uid}" = "-1" ]; then
|
||||
uid="$(stat -c '%u' /proc/"${pid}")"
|
||||
fi
|
||||
username="$(id -nu "${uid}")"
|
||||
# Only re-run if it was already running
|
||||
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user