mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-03 14:39:55 +00:00
Compare commits
8 Commits
feature/af
...
add-codeco
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca4bc9591a | ||
|
|
a78c96ddc6 | ||
|
|
f0ec44d30d | ||
|
|
3e61ccb162 | ||
|
|
a48c20d8d8 | ||
|
|
2b57a7d43b | ||
|
|
fa1e241aea | ||
|
|
e7c9182ff9 |
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -45,4 +45,11 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
61
.github/workflows/golang-test-linux.yml
vendored
61
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,16 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -276,9 +285,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,relay
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -326,7 +343,15 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,proxy
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
@@ -377,9 +402,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m ./signal/... ./shared/signal/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,signal
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -445,10 +478,18 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
go test -tags=devcert -coverprofile=coverage.txt \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/... ./shared/management/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,management
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [build-cache]
|
||||
@@ -687,6 +728,14 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
go test -tags=integration -coverprofile=coverage.txt \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: integration,management
|
||||
|
||||
39
.github/workflows/proto-version-check.yml
vendored
39
.github/workflows/proto-version-check.yml
vendored
@@ -20,15 +20,30 @@ jobs:
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const modifiedPbFiles = files.filter(
|
||||
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||
);
|
||||
if (modifiedPbFiles.length === 0) {
|
||||
console.log('No modified .pb.go files to check');
|
||||
// Cover renamed .pb.go files in addition to plain edits.
|
||||
// Renamed entries land under the new path with previous_filename
|
||||
// pointing at the base-side name, so we read the base content
|
||||
// from the old path when present.
|
||||
const changedPbFiles = files
|
||||
.filter(f => (f.status === 'modified' || f.status === 'renamed')
|
||||
&& f.filename.endsWith('.pb.go'))
|
||||
.map(f => ({
|
||||
headPath: f.filename,
|
||||
basePath: f.previous_filename || f.filename,
|
||||
}));
|
||||
if (changedPbFiles.length === 0) {
|
||||
console.log('No modified or renamed .pb.go files to check');
|
||||
return;
|
||||
}
|
||||
|
||||
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
// Matches the generator version headers protoc writes at the top
|
||||
// of generated files:
|
||||
// // protoc v3.21.12
|
||||
// // protoc-gen-go v1.26.0
|
||||
// // - protoc-gen-go-grpc v1.6.1 (grpc files prefix with "- ")
|
||||
// The optional "- " prefix and the optional -gen-go / -gen-go-grpc
|
||||
// suffixes keep the *_grpc.pb.go headers in scope.
|
||||
const versionPattern = /^\s*\/\/\s+(?:-\s+)?protoc(?:-gen-go(?:-grpc)?)?\s+v[\d.]+/;
|
||||
const baseSha = context.payload.pull_request.base.sha;
|
||||
const headSha = context.payload.pull_request.head.sha;
|
||||
|
||||
@@ -55,20 +70,22 @@ jobs:
|
||||
}
|
||||
|
||||
const violations = [];
|
||||
for (const file of modifiedPbFiles) {
|
||||
for (const file of changedPbFiles) {
|
||||
const [base, head] = await Promise.all([
|
||||
getVersionHeader(file.filename, baseSha),
|
||||
getVersionHeader(file.filename, headSha),
|
||||
getVersionHeader(file.basePath, baseSha),
|
||||
getVersionHeader(file.headPath, headSha),
|
||||
]);
|
||||
if (!base.ok || !head.ok) {
|
||||
core.warning(
|
||||
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
`Skipping ${file.headPath}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
file: file.basePath === file.headPath
|
||||
? file.headPath
|
||||
: `${file.basePath} → ${file.headPath}`,
|
||||
base: base.lines,
|
||||
head: head.lines,
|
||||
});
|
||||
|
||||
@@ -12,7 +12,13 @@ var (
|
||||
Short: "Print the NetBird's client application version",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
cmd.Println(version.NetbirdVersion())
|
||||
out := version.NetbirdVersion()
|
||||
if version.IsDevelopmentVersion(out) {
|
||||
if commit := version.NetbirdCommit(); commit != "" {
|
||||
out += "-" + commit
|
||||
}
|
||||
}
|
||||
cmd.Println(out)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -362,6 +362,10 @@ func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload
|
||||
return 0
|
||||
}
|
||||
|
||||
if pc := f.endpoint.capture.Load(); pc != nil {
|
||||
(*pc).Offer(fullPacket, true)
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@@ -346,6 +347,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
engineConfig.TempDir = mobileDependency.TempDir
|
||||
// Leave StateDir empty when there is no state path so a disk-backed
|
||||
// syncstore falls back to os.TempDir() instead of filepath.Dir("") == ".".
|
||||
if path != "" {
|
||||
engineConfig.StateDir = filepath.Dir(path)
|
||||
}
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -56,6 +55,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/syncstore"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
@@ -148,6 +148,10 @@ type EngineConfig struct {
|
||||
|
||||
LogPath string
|
||||
TempDir string
|
||||
|
||||
// StateDir is the directory holding the state file. The sync response
|
||||
// (network map) is serialized here on platforms that persist it to disk.
|
||||
StateDir string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
@@ -226,10 +230,15 @@ type Engine struct {
|
||||
|
||||
afpacketCapture *capture.AFPacketCapture
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
// Sync response persistence (protected by syncRespMux).
|
||||
// syncStore is nil unless persistence has been enabled; its presence is
|
||||
// what marks persistence as active. The backend (disk or memory) is
|
||||
// selected per-platform; see the syncstore package. syncStoreDir is where
|
||||
// a disk-backed store serializes to.
|
||||
syncRespMux sync.RWMutex
|
||||
syncStore syncstore.Store
|
||||
syncStoreDir string
|
||||
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
@@ -292,6 +301,7 @@ func NewEngine(
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
syncStoreDir: config.StateDir,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -913,19 +923,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// Store sync response if persistence is enabled
|
||||
if enabled {
|
||||
e.syncRespMux.Lock()
|
||||
e.latestSyncResponse = update
|
||||
e.syncRespMux.Unlock()
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -1813,6 +1822,18 @@ func (e *Engine) close() {
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
|
||||
// Drop any persisted sync response so its network map does not linger on
|
||||
// disk after the engine stops (and cannot leak into a later run).
|
||||
e.syncRespMux.Lock()
|
||||
store := e.syncStore
|
||||
e.syncStore = nil
|
||||
e.syncRespMux.Unlock()
|
||||
if store != nil {
|
||||
if err := store.Clear(); err != nil {
|
||||
log.Warnf("failed to clear persisted sync response on close: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -2142,45 +2163,42 @@ func (e *Engine) stopDNSServer() {
|
||||
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||
}
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence.
|
||||
// The store is only instantiated while persistence is enabled; construction
|
||||
// itself drops any stale data left over from an earlier run (see syncstore).
|
||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
e.syncRespMux.Lock()
|
||||
defer e.syncRespMux.Unlock()
|
||||
|
||||
if enabled == e.persistSyncResponse {
|
||||
if enabled == (e.syncStore != nil) {
|
||||
return
|
||||
}
|
||||
e.persistSyncResponse = enabled
|
||||
log.Debugf("Sync response persistence is set to %t", enabled)
|
||||
|
||||
if !enabled {
|
||||
e.latestSyncResponse = nil
|
||||
if err := e.syncStore.Clear(); err != nil {
|
||||
log.Warnf("failed to clear persisted sync response: %v", err)
|
||||
}
|
||||
e.syncStore = nil
|
||||
return
|
||||
}
|
||||
|
||||
e.syncStore = syncstore.New(e.syncStoreDir)
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
// Hold the lock for the whole Get so the store cannot be cleared
|
||||
// (disabled / engine close) mid-call.
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
latest := e.latestSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if !enabled {
|
||||
if e.syncStore == nil {
|
||||
return nil, errors.New("sync response persistence is disabled")
|
||||
}
|
||||
|
||||
if latest == nil {
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to clone sync response")
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
//nolint:nilnil
|
||||
return e.syncStore.Get()
|
||||
}
|
||||
|
||||
// GetWgAddr returns the wireguard address
|
||||
@@ -2216,7 +2234,7 @@ func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
if e.config.DisableServerRoutes || e.config.BlockInbound {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -11,7 +13,7 @@ var (
|
||||
)
|
||||
|
||||
func IsSupported(agentVersion string) bool {
|
||||
if agentVersion == "development" {
|
||||
if nbversion.IsDevelopmentVersion(agentVersion) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -310,8 +310,12 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Returns the
|
||||
// zero State and false when no peer matches or the input is empty.
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
@@ -324,6 +328,11 @@ func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
|
||||
|
||||
@@ -90,6 +90,28 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
status.ReplaceOfflinePeers([]State{
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
|
||||
99
client/internal/syncstore/disk.go
Normal file
99
client/internal/syncstore/disk.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package syncstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// syncResponseFileName is the name of the file the sync response is serialized
|
||||
// to, placed inside the configured directory (the state directory).
|
||||
const syncResponseFileName = "networkmap.pb"
|
||||
|
||||
// diskStore serializes the latest sync response to a file on disk instead of
|
||||
// keeping it in memory. This trades disk I/O for a much smaller memory
|
||||
// footprint, which matters on memory-constrained platforms (iOS).
|
||||
type diskStore struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
}
|
||||
|
||||
// NewDiskStore returns a Store that serializes the sync response to a file in
|
||||
// the given directory. If dir is empty it falls back to the OS temp directory.
|
||||
//
|
||||
// Any file left over from a previous run is removed on construction so a fresh
|
||||
// store never reads stale data (e.g. another profile's network map).
|
||||
func NewDiskStore(dir string) Store {
|
||||
if dir == "" {
|
||||
dir = os.TempDir()
|
||||
}
|
||||
s := &diskStore{
|
||||
path: filepath.Join(dir, syncResponseFileName),
|
||||
}
|
||||
if err := s.Clear(); err != nil {
|
||||
log.Warnf("failed to clear stale sync response file: %v", err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *diskStore) Set(resp *mgmProto.SyncResponse) error {
|
||||
if resp == nil {
|
||||
return s.Clear()
|
||||
}
|
||||
|
||||
bs, err := proto.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sync response: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := util.WriteBytesWithRestrictedPermission(context.Background(), s.path, bs); err != nil {
|
||||
return fmt.Errorf("write sync response to %s: %w", s.path, err)
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted to %s (%d bytes)", s.path, len(bs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *diskStore) Get() (*mgmProto.SyncResponse, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
bs, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
//nolint:nilnil // nil,nil means "nothing stored", per the Store contract; preserve the original behaviour
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read sync response from %s: %w", s.path, err)
|
||||
}
|
||||
|
||||
resp := &mgmProto.SyncResponse{}
|
||||
if err := proto.Unmarshal(bs, resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal sync response: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("retrieving latest sync response from %s (%d bytes)", s.path, len(bs))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *diskStore) Clear() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := os.Remove(s.path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("remove sync response file %s: %w", s.path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
9
client/internal/syncstore/factory_ios.go
Normal file
9
client/internal/syncstore/factory_ios.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build ios
|
||||
|
||||
package syncstore
|
||||
|
||||
// New returns the platform default store. On iOS the sync response is
|
||||
// serialized to disk (in dir) to keep it out of the constrained process memory.
|
||||
func New(dir string) Store {
|
||||
return NewDiskStore(dir)
|
||||
}
|
||||
9
client/internal/syncstore/factory_other.go
Normal file
9
client/internal/syncstore/factory_other.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !ios
|
||||
|
||||
package syncstore
|
||||
|
||||
// New returns the platform default store. On all non-iOS platforms the sync
|
||||
// response is kept in memory; dir is unused.
|
||||
func New(_ string) Store {
|
||||
return NewMemoryStore()
|
||||
}
|
||||
56
client/internal/syncstore/memory.go
Normal file
56
client/internal/syncstore/memory.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package syncstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// memoryStore keeps the latest sync response in memory.
|
||||
type memoryStore struct {
|
||||
mu sync.RWMutex
|
||||
latest *mgmProto.SyncResponse
|
||||
}
|
||||
|
||||
// NewMemoryStore returns a Store that keeps the sync response in memory.
|
||||
func NewMemoryStore() Store {
|
||||
return &memoryStore{}
|
||||
}
|
||||
|
||||
func (s *memoryStore) Set(resp *mgmProto.SyncResponse) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.latest = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get() (*mgmProto.SyncResponse, error) {
|
||||
s.mu.RLock()
|
||||
latest := s.latest
|
||||
s.mu.RUnlock()
|
||||
|
||||
if latest == nil {
|
||||
//nolint:nilnil // nil,nil means "nothing stored", per the Store contract; preserve the original behaviour
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debugf("retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("clone sync response")
|
||||
}
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Clear() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.latest = nil
|
||||
return nil
|
||||
}
|
||||
29
client/internal/syncstore/syncstore.go
Normal file
29
client/internal/syncstore/syncstore.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Package syncstore stores the latest Management sync response (which carries
|
||||
// the network map) for debug bundle generation.
|
||||
//
|
||||
// The storage backend is selected at build time per operating system: on iOS
|
||||
// the response is serialized to disk to keep it out of the (tightly
|
||||
// constrained) process memory, while on all other platforms it is kept in
|
||||
// memory. The backend is chosen by the New constructor; see factory_ios.go and
|
||||
// factory_other.go.
|
||||
package syncstore
|
||||
|
||||
import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Store persists the latest sync response and returns it on demand.
|
||||
//
|
||||
// Implementations must be safe for concurrent use.
|
||||
type Store interface {
|
||||
// Set stores the given sync response, replacing any previously stored one.
|
||||
Set(resp *mgmProto.SyncResponse) error
|
||||
|
||||
// Get returns the stored sync response, or nil if none is stored.
|
||||
// The returned value is an independent copy that the caller may retain.
|
||||
Get() (*mgmProto.SyncResponse, error)
|
||||
|
||||
// Clear removes any stored sync response. It is safe to call when nothing
|
||||
// is stored.
|
||||
Clear() error
|
||||
}
|
||||
@@ -19,8 +19,6 @@ import (
|
||||
|
||||
const (
|
||||
latestVersion = "latest"
|
||||
// this version will be ignored
|
||||
developmentVersion = "development"
|
||||
)
|
||||
|
||||
var errNoUpdateState = errors.New("no update state found")
|
||||
@@ -483,7 +481,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
|
||||
}
|
||||
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
|
||||
if m.currentVersion == developmentVersion {
|
||||
if version.IsDevelopmentVersion(m.currentVersion) {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
@@ -514,7 +515,7 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
if version.IsDevelopmentVersion(peer.Meta.WtVersion) {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
|
||||
@@ -932,7 +932,11 @@ func (s *Service) validateL4Target(target *Target) error {
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||
// Cluster targets resolve their upstream host:port from the target's
|
||||
// own Host/Port fields just like the other L4 types — buildPathMappings
|
||||
// emits net.JoinHostPort(target.Host, target.Port) for every L4
|
||||
// target, so allowing port=0 here would let ":0" reach the proxy.
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
|
||||
@@ -1176,7 +1176,12 @@ func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||
}
|
||||
|
||||
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
// TestValidate_L4ClusterTarget_RequiresPort confirms that an L4 cluster
|
||||
// target without an explicit port is rejected. buildPathMappings emits
|
||||
// net.JoinHostPort(target.Host, target.Port) for every L4 target — so
|
||||
// allowing port=0 would let the proxy ship ":0" upstreams. The port
|
||||
// requirement is the same as every other L4 target type.
|
||||
func TestValidate_L4ClusterTarget_RequiresPort(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Mode = ModeTCP
|
||||
rp.ListenPort = 9000
|
||||
@@ -1186,7 +1191,12 @@ func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||
assert.ErrorContains(t, rp.Validate(), "port is required",
|
||||
"L4 cluster target must require an explicit port like other L4 target types")
|
||||
|
||||
rp.Targets[0].Port = 5432
|
||||
rp.Targets[0].Host = "db.lan"
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target with host:port must validate")
|
||||
}
|
||||
|
||||
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||
|
||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, "", domain, auth.MethodOIDC, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
@@ -394,6 +394,10 @@ func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionUsersManager struct {
|
||||
store store.Store
|
||||
}
|
||||
@@ -401,3 +405,24 @@ type testValidateSessionUsersManager struct {
|
||||
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(user.AutoGroups) == 0 {
|
||||
return user, nil, nil
|
||||
}
|
||||
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups := make([]*types.Group, 0, len(user.AutoGroups))
|
||||
for _, id := range user.AutoGroups {
|
||||
if g, ok := groupsMap[id]; ok && g != nil {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
}
|
||||
return user, groups, nil
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const remoteJobsMinVer = "0.64.0"
|
||||
@@ -372,7 +373,7 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
|
||||
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
|
||||
if !version.IsDevelopmentVersion(p.Meta.WtVersion) && (!meetMinVer || err != nil) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
|
||||
}
|
||||
|
||||
|
||||
@@ -4734,7 +4734,13 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
||||
result := tx.
|
||||
Take(&peer, fmt.Sprintf("account_id = ? AND %s = ?", column), accountID, jsonValue)
|
||||
if result.Error != nil {
|
||||
// no logging here
|
||||
// A tunnel-IP miss is an expected outcome (e.g. the proxy's
|
||||
// ValidateTunnelPeer probing an address that isn't in the
|
||||
// account roster); surface it as NotFound so callers can tell
|
||||
// it apart from a real store failure.
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer with ip %s not found", ip.String())
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||
}
|
||||
|
||||
@@ -5962,6 +5968,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
}
|
||||
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
|
||||
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
if os.Getenv("CI") == "true" && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
|
||||
@@ -491,6 +491,27 @@ func Test_GetAccount(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestSqlStore_GetPeerByIP_NotFound pins the not-found semantics the
|
||||
// proxy's ValidateTunnelPeer relies on: a tunnel-IP that isn't in the
|
||||
// account roster must surface as a NotFound error (not a generic
|
||||
// Internal) so callers can distinguish an expected miss from a real
|
||||
// store failure. A known IP still resolves.
|
||||
func TestSqlStore_GetPeerByIP_NotFound(t *testing.T) {
|
||||
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peer, err := store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("192.168.0.0"))
|
||||
require.NoError(t, err, "known tunnel IP must resolve")
|
||||
require.NotNil(t, peer)
|
||||
|
||||
_, err = store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("100.65.0.99"))
|
||||
require.Error(t, err, "unknown tunnel IP must error")
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error must be a status error")
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "tunnel-IP miss must be NotFound, not Internal")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePeer(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1804,7 +1805,7 @@ func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *n
|
||||
|
||||
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
|
||||
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
if version.IsDevelopmentVersion(peerVer) {
|
||||
return supportedFeatures{true, true}
|
||||
}
|
||||
|
||||
|
||||
@@ -646,41 +646,7 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
expectedPorts: []string{"20-25", "10-100", "22022"},
|
||||
},
|
||||
{
|
||||
name: "dev suffix version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.50.0-dev",
|
||||
Flags: nbpeer.Flags{ServerSSHAllowed: true},
|
||||
},
|
||||
},
|
||||
rule: &PolicyRule{
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
},
|
||||
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
|
||||
expectedPorts: []string{"22", "22022"},
|
||||
},
|
||||
{
|
||||
name: "dev suffix version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "dev",
|
||||
Flags: nbpeer.Flags{ServerSSHAllowed: true},
|
||||
},
|
||||
},
|
||||
rule: &PolicyRule{
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
},
|
||||
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
|
||||
expectedPorts: []string{"22", "22022"},
|
||||
},
|
||||
{
|
||||
name: "development suffix version supports all features",
|
||||
name: "development version supports all features",
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
SSHEnabled: true,
|
||||
|
||||
@@ -214,7 +214,10 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid --trusted-proxies: %w", err)
|
||||
}
|
||||
|
||||
srv := proxy.New(proxy.Config{
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
srv := proxy.New(ctx, proxy.Config{
|
||||
ListenAddr: addr,
|
||||
Logger: logger,
|
||||
Version: Version,
|
||||
@@ -251,9 +254,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
})
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
return srv.ListenAndServe(ctx, addr)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -42,7 +43,7 @@ const privateInboundPortHTTPS = 443
|
||||
const privateInboundPortHTTP = 80
|
||||
|
||||
// inboundManager wires per-account inbound listeners into the proxy
|
||||
// pipeline when --private-inbound is enabled. When disabled the manager
|
||||
// pipeline when --private is enabled. When disabled the manager
|
||||
// is nil and every method on *Server that touches it short-circuits.
|
||||
type inboundManager struct {
|
||||
logger *log.Logger
|
||||
@@ -55,15 +56,18 @@ type inboundManager struct {
|
||||
}
|
||||
|
||||
// inboundEntry owns the listeners, router and HTTP servers for a single
|
||||
// account's embedded netstack.
|
||||
// account's embedded netstack. errorLogWriters retain the logrus pipe
|
||||
// writers backing each http.Server's ErrorLog so tearDown can close
|
||||
// them — otherwise the pipe + its scanner goroutine leak per account.
|
||||
type inboundEntry struct {
|
||||
router *nbtcp.Router
|
||||
tlsListener net.Listener
|
||||
plainListener net.Listener
|
||||
httpsServer *http.Server
|
||||
httpServer *http.Server
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
router *nbtcp.Router
|
||||
tlsListener net.Listener
|
||||
plainListener net.Listener
|
||||
httpsServer *http.Server
|
||||
httpServer *http.Server
|
||||
errorLogWriters []*io.PipeWriter
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// pendingInboundRoute holds a route that arrived before the account's
|
||||
@@ -147,30 +151,34 @@ func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID,
|
||||
return types.WithOverlayOrigin(ctx)
|
||||
}
|
||||
|
||||
httpsErrLog, httpsErrW := newInboundErrorLog(m.logger, "https", accountID)
|
||||
httpErrLog, httpErrW := newInboundErrorLog(m.logger, "http", accountID)
|
||||
|
||||
httpsServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
TLSConfig: m.tlsConfig,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
|
||||
ErrorLog: httpsErrLog,
|
||||
ConnContext: markOverlayOrigin,
|
||||
}
|
||||
httpServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
|
||||
ErrorLog: httpErrLog,
|
||||
ConnContext: markOverlayOrigin,
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
entry := &inboundEntry{
|
||||
router: router,
|
||||
tlsListener: tlsListener,
|
||||
plainListener: plainListener,
|
||||
httpsServer: httpsServer,
|
||||
httpServer: httpServer,
|
||||
cancel: cancel,
|
||||
router: router,
|
||||
tlsListener: tlsListener,
|
||||
plainListener: plainListener,
|
||||
httpsServer: httpsServer,
|
||||
httpServer: httpServer,
|
||||
errorLogWriters: []*io.PipeWriter{httpsErrW, httpErrW},
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
entry.wg.Add(1)
|
||||
@@ -237,6 +245,14 @@ func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry
|
||||
m.logger.Debugf("close per-account plain listener: %v", err)
|
||||
}
|
||||
entry.wg.Wait()
|
||||
// Close the ErrorLog pipes only after the http.Servers have fully
|
||||
// stopped so any straggling stdlib write doesn't race with the
|
||||
// close. Each writer also tears down the logrus scanner goroutine.
|
||||
for _, w := range entry.errorLogWriters {
|
||||
if err := w.Close(); err != nil {
|
||||
m.logger.Debugf("close per-account inbound error log writer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddRoute records an SNI/host route on the account's per-account router.
|
||||
@@ -374,7 +390,7 @@ func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListene
|
||||
}
|
||||
|
||||
// Snapshot returns the inbound listener state for every account that has
|
||||
// a live listener at call time. Empty when --private-inbound is off or
|
||||
// a live listener at call time. Empty when --private is off or
|
||||
// no accounts have come up yet.
|
||||
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
|
||||
if m == nil {
|
||||
@@ -497,7 +513,7 @@ func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
|
||||
// peerstore lookup to every request's context before delegating to next.
|
||||
// Calling on the host-level listener is a no-op because that path never
|
||||
// installs this wrapper, so the existing behaviour stays byte-for-byte
|
||||
// identical when --private-inbound is off or the request didn't arrive
|
||||
// identical when --private is off or the request didn't arrive
|
||||
// on a per-account listener.
|
||||
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
|
||||
if lookup == nil {
|
||||
@@ -538,10 +554,14 @@ func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.Inboun
|
||||
}
|
||||
|
||||
// newInboundErrorLog routes a per-account http.Server's stdlib error
|
||||
// stream through logrus at warn level.
|
||||
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
|
||||
return stdlog.New(logger.WithFields(log.Fields{
|
||||
// stream through logrus at warn level. The returned PipeWriter must be
|
||||
// closed by the caller (tearDown) once the http.Server has shut down —
|
||||
// otherwise the pipe and its scanner goroutine leak per account, see
|
||||
// logrus.Entry.WriterLevel.
|
||||
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) (*stdlog.Logger, *io.PipeWriter) {
|
||||
w := logger.WithFields(log.Fields{
|
||||
"inbound-http": scheme,
|
||||
"account_id": accountID,
|
||||
}).WriterLevel(log.WarnLevel), "", 0)
|
||||
}).WriterLevel(log.WarnLevel)
|
||||
return stdlog.New(w, "", 0), w
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -110,7 +111,7 @@ func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
|
||||
// Construct a NetBird transport. We can't actually start the embedded
|
||||
// client here (that needs a real management server), but we can
|
||||
// confirm that the lifecycle callbacks are registered.
|
||||
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
|
||||
s.netbird = roundtrip.NewNetBird(t.Context(), "test", "test", roundtrip.ClientConfig{
|
||||
MgmtAddr: "http://invalid.test",
|
||||
}, quietLogger(), nil, fakeMgmtClient{})
|
||||
|
||||
@@ -139,7 +140,7 @@ func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
|
||||
|
||||
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
|
||||
// bit reported upstream tracks --private exclusively. The previous
|
||||
// --private-inbound flag has been folded into --private.
|
||||
// --private flag has been folded into --private.
|
||||
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -318,7 +319,7 @@ func TestInboundManager_ListenerInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestInboundManager_NilManagerSafe ensures the observability accessors
|
||||
// are safe to call when --private-inbound is off (nil manager).
|
||||
// are safe to call when --private is off (nil manager).
|
||||
func TestInboundManager_NilManagerSafe(t *testing.T) {
|
||||
var mgr *inboundManager
|
||||
_, ok := mgr.ListenerInfo("anything")
|
||||
@@ -482,6 +483,38 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
|
||||
}
|
||||
|
||||
// TestNewInboundErrorLog_WriterIsCloseable guards the close path on the
|
||||
// logrus PipeWriter that backs each per-account http.Server's ErrorLog.
|
||||
// logrus.Entry.WriterLevel returns an *io.PipeWriter that owns a pipe +
|
||||
// scanner goroutine; the caller must Close() it on teardown or the
|
||||
// resources leak per account. The contract is verified two ways:
|
||||
//
|
||||
// - the constructor returns a non-nil writer the caller can keep,
|
||||
// - writing to the writer after Close() fails with io.ErrClosedPipe,
|
||||
// which is the only externally observable sign that Close was wired.
|
||||
//
|
||||
// A leaking refactor (forgetting to thread the writer to tearDown, or
|
||||
// dropping the Close call) would still pass this test individually but
|
||||
// fail an integration goleak check; this unit test is the cheap first
|
||||
// line of defence.
|
||||
func TestNewInboundErrorLog_WriterIsCloseable(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
stdLog, writer := newInboundErrorLog(logger, "https", types.AccountID("acct-1"))
|
||||
|
||||
require.NotNil(t, stdLog, "newInboundErrorLog must return a non-nil *log.Logger")
|
||||
require.NotNil(t, writer, "newInboundErrorLog must return the underlying PipeWriter so tearDown can Close it")
|
||||
|
||||
// First Close succeeds.
|
||||
require.NoError(t, writer.Close(), "PipeWriter.Close should succeed the first time")
|
||||
|
||||
// After Close, the writer must refuse new writes — that's the only
|
||||
// behavioural signal that the pipe (and its scanner goroutine) has
|
||||
// shut down.
|
||||
_, err := writer.Write([]byte("post-close write\n"))
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe,
|
||||
"writes after Close must surface io.ErrClosedPipe so callers know the writer is gone")
|
||||
}
|
||||
|
||||
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
|
||||
// 127.0.0.1 — only used by tests that need a working TLS handshake.
|
||||
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
|
||||
|
||||
@@ -346,13 +346,15 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
// management unreachable, peer unknown, user not in group) returns false so
|
||||
// the caller falls back to the existing OIDC scheme dispatch.
|
||||
//
|
||||
// Phase 3 adds a local-first short-circuit: when the request arrived on a
|
||||
// per-account inbound listener the context carries a peerstore lookup
|
||||
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
|
||||
// roster the proxy denies fast without calling management. If the lookup
|
||||
// confirms a known peer the RPC still runs for the user-identity tail
|
||||
// (UserID + group access), but its result is cached for tunnelCacheTTL so
|
||||
// repeat requests skip management entirely.
|
||||
// The fast-path is gated on TunnelLookupFromContext(r.Context()) being
|
||||
// present — that context value is attached only by the per-account
|
||||
// inbound (overlay) listener. The host listener never sets it, so a
|
||||
// public client whose source IP happens to fall inside an RFC1918 / ULA
|
||||
// / CGNAT range can't impersonate a mesh peer by colliding with a
|
||||
// tunnel-IP. Once we know the request arrived over WireGuard the
|
||||
// per-account peerstore lookup is consulted: a miss denies fast (no
|
||||
// management round-trip), a hit gates the cached ValidateTunnelPeer RPC
|
||||
// that mints the session JWT.
|
||||
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
if mw.sessionValidator == nil {
|
||||
return false
|
||||
@@ -361,18 +363,24 @@ func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Reque
|
||||
if !clientIP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Anti-spoof: only honour the tunnel-peer fast-path on requests that
|
||||
// were stamped by an overlay listener. Without that marker an
|
||||
// attacker could send a request from a colliding RFC1918 / CGNAT
|
||||
// source on the public listener and bypass operator auth.
|
||||
lookup := TunnelLookupFromContext(r.Context())
|
||||
if lookup == nil {
|
||||
return false
|
||||
}
|
||||
if !isTunnelSourceIP(clientIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
|
||||
if _, ok := lookup(clientIP); !ok {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": clientIP,
|
||||
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
|
||||
return false
|
||||
}
|
||||
if _, ok := lookup(clientIP); !ok {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": clientIP,
|
||||
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
|
||||
return false
|
||||
}
|
||||
|
||||
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
|
||||
|
||||
@@ -1227,3 +1227,93 @@ func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
|
||||
}
|
||||
|
||||
// stubTunnelValidator records ValidateTunnelPeer calls so a test can
|
||||
// assert whether the fast-path reached management.
|
||||
type stubTunnelValidator struct {
|
||||
called bool
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
|
||||
return nil, errors.New("not used in this test")
|
||||
}
|
||||
|
||||
func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
s.called = true
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
// TestProtect_TunnelPeerFastPath_RequiresInboundMarker guards the
|
||||
// anti-spoof gate: a request with an RFC1918 source IP arriving on the
|
||||
// public listener (no TunnelLookupFromContext attached) must not be
|
||||
// allowed to take the tunnel-peer fast-path. Without this gate a public
|
||||
// client whose source IP happens to fall inside an RFC1918 range could
|
||||
// bypass the configured auth scheme by colliding with a known tunnel
|
||||
// IP.
|
||||
func TestProtect_TunnelPeerFastPath_RequiresInboundMarker(t *testing.T) {
|
||||
validator := &stubTunnelValidator{
|
||||
resp: &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "should-not-be-used",
|
||||
UserId: "user-1",
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.StandardLogger(), validator, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Request from an RFC1918 source IP on the public listener — no
|
||||
// TunnelLookupFromContext attached. The fast-path must reject this
|
||||
// and fall through to the PIN scheme (which renders 401 on plain
|
||||
// HTTP for a non-authenticated request).
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.64.0.5:5000"
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, validator.called,
|
||||
"ValidateTunnelPeer must not be invoked when the request lacks the inbound TunnelLookup marker")
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code,
|
||||
"without the inbound marker the request must fall through to the operator auth scheme")
|
||||
}
|
||||
|
||||
// TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker verifies
|
||||
// the positive side: a request marked as overlay-origin (carrying the
|
||||
// TunnelLookup context value) and matching a tunnel-IP range does take
|
||||
// the fast-path and reach management.
|
||||
func TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker(t *testing.T) {
|
||||
validator := &stubTunnelValidator{
|
||||
resp: &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "tunnel-session-token",
|
||||
UserId: "user-1",
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.StandardLogger(), validator, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.64.0.5:5000"
|
||||
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.True(t, validator.called,
|
||||
"ValidateTunnelPeer must run when the request carries the inbound TunnelLookup marker")
|
||||
assert.Equal(t, http.StatusOK, rec.Code,
|
||||
"a successful tunnel-peer validation must forward to the next handler")
|
||||
}
|
||||
|
||||
@@ -101,7 +101,10 @@ func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
cd := proxy.NewCapturedData("")
|
||||
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
r = r.WithContext(proxy.WithCapturedData(WithTunnelLookup(r.Context(), lookup), cd))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
@@ -148,9 +151,13 @@ func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
|
||||
// behaviour stays intact on the host-level listener (no lookup attached).
|
||||
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
|
||||
// TestForwardWithTunnelPeer_NoLookupRefusesFastPath guards the
|
||||
// anti-spoof gate: requests that didn't arrive on the per-account
|
||||
// inbound listener (no TunnelLookup attached) must never reach
|
||||
// management's ValidateTunnelPeer, even when the source IP looks like
|
||||
// a tunnel address. A colliding RFC1918 / CGNAT source on the public
|
||||
// listener would otherwise impersonate a mesh peer.
|
||||
func TestForwardWithTunnelPeer_NoLookupRefusesFastPath(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
@@ -165,9 +172,9 @@ func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.True(t, handled, "host-level path forwards on positive RPC result")
|
||||
assert.True(t, called, "next handler runs on host-level success")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
|
||||
assert.False(t, handled, "fast-path must refuse without the inbound marker")
|
||||
assert.False(t, called, "next handler must not run")
|
||||
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must not be invoked without the inbound marker")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
|
||||
@@ -201,8 +208,13 @@ func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
@@ -226,11 +238,21 @@ func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
|
||||
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
|
||||
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
|
||||
|
||||
// The fast-path requires the inbound-listener marker on the context.
|
||||
// The peerstore lookup itself is account-agnostic at this level
|
||||
// (one TunnelLookupFunc per account is attached by inbound.go); a
|
||||
// trivial "always hit" lookup is enough to exercise the cache-key
|
||||
// branch this test covers.
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
|
||||
for _, host := range []string{"svc-a.example", "svc-b.example"} {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
|
||||
r.Host = host
|
||||
r.RemoteAddr = "100.64.0.10:55555"
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
config, _ := mw.getDomainConfig(host)
|
||||
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
require.True(t, handled, "host %s should forward", host)
|
||||
@@ -314,9 +336,17 @@ func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Per-account inbound listener attaches WithTunnelLookup; without it
|
||||
// forwardWithTunnelPeer refuses to take the fast-path. Mirror the
|
||||
// real flow so this test exercises the post-gating success branch.
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, true
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.10:55555"
|
||||
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func (h *Handler) SetCertStatus(cs certStatus) {
|
||||
|
||||
// SetInboundProvider wires per-account inbound listener observability.
|
||||
// Pass nil (or skip the call) to keep the inbound section out of debug
|
||||
// responses on proxies that don't run --private-inbound.
|
||||
// responses on proxies that don't run --private.
|
||||
func (h *Handler) SetInboundProvider(p InboundProvider) {
|
||||
h.inbound = p
|
||||
}
|
||||
|
||||
@@ -66,6 +66,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Loop guard for private services: a peer that hosts the target
|
||||
// dialing its own service URL would round-trip its own traffic
|
||||
// through the proxy and back over WG to itself. Refuse the request
|
||||
// with 421 (Misdirected Request) so the caller sees an explicit
|
||||
// error instead of silently doubling tunnel traffic.
|
||||
if p.isSelfTargetLoop(r, result.target.URL) {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginNoRoute)
|
||||
}
|
||||
requestID := getRequestID(r)
|
||||
web.ServeErrorPage(w, r, http.StatusMisdirectedRequest, "Loop Detected",
|
||||
"This peer is the target of the requested service. Reach the backend directly instead of dialing the public service URL from the same machine.",
|
||||
requestID, web.ErrorStatus{Proxy: true, Destination: false})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
// Set the account ID in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, result.accountID)
|
||||
@@ -107,6 +123,32 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// isSelfTargetLoop reports whether an overlay-origin request is about to
|
||||
// be forwarded back to the very peer that initiated it. The detection
|
||||
// is intentionally narrow: it only fires when the request arrived on
|
||||
// the per-account inbound (overlay) listener (so we're confident the
|
||||
// source address is the caller's tunnel IP), and only when the resolved
|
||||
// target host matches that tunnel IP. Catching this here returns 421 to
|
||||
// the caller instead of letting the proxy round-trip its own traffic
|
||||
// over WG twice.
|
||||
func (p *ReverseProxy) isSelfTargetLoop(r *http.Request, target *url.URL) bool {
|
||||
if target == nil {
|
||||
return false
|
||||
}
|
||||
if !types.IsOverlayOrigin(r.Context()) {
|
||||
return false
|
||||
}
|
||||
srcIP := extractHostIP(r.RemoteAddr)
|
||||
if !srcIP.IsValid() {
|
||||
return false
|
||||
}
|
||||
targetIP, err := netip.ParseAddr(target.Hostname())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return srcIP.Unmap() == targetIP.Unmap()
|
||||
}
|
||||
|
||||
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
|
||||
// inbound requests to target the backend service while setting security-relevant
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -1285,6 +1286,103 @@ func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
|
||||
"X-NetBird-Groups must not be set when every group label is rejected")
|
||||
}
|
||||
|
||||
// nopOKTransport returns 200 for every request without dialing — used
|
||||
// by the self-target-loop tests so the non-loop cases don't pay a real
|
||||
// TCP-dial timeout.
|
||||
type nopOKTransport struct{}
|
||||
|
||||
func (nopOKTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Header: http.Header{}}, nil
|
||||
}
|
||||
|
||||
// TestServeHTTP_SelfTargetLoopReturns421 covers the loop guard for
|
||||
// private services: when a peer dials a service whose only target is
|
||||
// the peer itself, the proxy must refuse with 421 (Misdirected
|
||||
// Request) rather than round-tripping the request back over WG to
|
||||
// the same peer.
|
||||
func TestServeHTTP_SelfTargetLoopReturns421(t *testing.T) {
|
||||
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
|
||||
rp.AddMapping(Mapping{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Host: "private.svc",
|
||||
Paths: map[string]*PathTarget{
|
||||
"/": {
|
||||
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.5:55555"
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
rp.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusMisdirectedRequest, rec.Code,
|
||||
"a peer dialing a service whose target is itself must get 421")
|
||||
}
|
||||
|
||||
// TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough verifies
|
||||
// the guard is scoped to overlay-origin requests. A public-listener
|
||||
// request that happens to share a source IP with the target host must
|
||||
// not be misinterpreted as a loop — the gating relies on the inbound
|
||||
// marker being attached only by the per-account overlay listener.
|
||||
func TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough(t *testing.T) {
|
||||
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
|
||||
rp.AddMapping(Mapping{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Host: "public.svc",
|
||||
Paths: map[string]*PathTarget{
|
||||
"/": {
|
||||
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://public.svc/", nil)
|
||||
req.Host = "public.svc"
|
||||
req.RemoteAddr = "100.64.0.5:55555"
|
||||
// No WithOverlayOrigin → the guard must not fire.
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
rp.ServeHTTP(rec, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
|
||||
"a non-overlay request with a colliding source IP must not be flagged as a loop")
|
||||
}
|
||||
|
||||
// TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough confirms
|
||||
// that overlay-origin requests with a source IP that does *not* match
|
||||
// the target host are forwarded normally.
|
||||
func TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough(t *testing.T) {
|
||||
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
|
||||
rp.AddMapping(Mapping{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Host: "private.svc",
|
||||
Paths: map[string]*PathTarget{
|
||||
"/": {
|
||||
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.99:55555" // different from the target
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
rp.ServeHTTP(rec, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
|
||||
"overlay request with a non-matching source IP must not be flagged as a loop")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
|
||||
// that carry CapturedData with no identity fields populated (e.g. the
|
||||
// auth middleware ran but the request didn't authenticate). Both
|
||||
|
||||
@@ -152,6 +152,7 @@ type managementClient interface {
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
|
||||
type NetBird struct {
|
||||
ctx context.Context
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
clientCfg ClientConfig
|
||||
@@ -213,7 +214,11 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -242,8 +247,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -307,7 +314,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
// The embedded proxy peer must never be a stepping stone into
|
||||
// the proxy host's LAN: it only exists to reach NetBird mesh
|
||||
// targets or, when direct_upstream is set, the host network
|
||||
@@ -355,8 +362,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered services on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
// runClientStartup starts the client and notifies registered services on
|
||||
// success. This function runs in a goroutine launched from AddPeer, so it
|
||||
// must never inherit the caller's request-scoped context — a canceled
|
||||
// request must not abort the inbound listener bring-up or the management
|
||||
// status notification. The embedded client.Start gets its own bounded
|
||||
// startCtx; once Start succeeds, notifyClientReady takes over with a
|
||||
// fresh context.Background() (see that function for the contract).
|
||||
func (n *NetBird) runClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -369,7 +382,17 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect services to notify outside the lock.
|
||||
n.notifyClientReady(accountID, client)
|
||||
}
|
||||
|
||||
// notifyClientReady marks the account's client as started, fires the
|
||||
// readyHandler hook, and notifies management of the new tunnel
|
||||
// connection for every registered service. It is split out of
|
||||
// runClientStartup so a regression test can drive the post-Start tail
|
||||
// without needing a live embedded client. The contract that the
|
||||
// hooks/notifier see context.Background() — never the AddPeer caller's
|
||||
// ctx — lives here.
|
||||
func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Client) {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
@@ -385,7 +408,7 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if readyHandler != nil {
|
||||
state := readyHandler(ctx, accountID, client)
|
||||
state := readyHandler(n.ctx, accountID, client)
|
||||
n.clientsMux.Lock()
|
||||
if e, ok := n.clients[accountID]; ok {
|
||||
e.inbound = state
|
||||
@@ -404,7 +427,7 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
return
|
||||
}
|
||||
for _, sn := range toNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(n.ctx, accountID, sn.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": sn.key,
|
||||
@@ -666,11 +689,12 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
|
||||
// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
|
||||
// OS-assigned port. A fixed port only works with single-account deployments;
|
||||
// multiple accounts will fail to bind the same port.
|
||||
func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
||||
func NewNetBird(ctx context.Context, proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
ctx: ctx,
|
||||
proxyID: proxyID,
|
||||
proxyAddr: proxyAddr,
|
||||
clientCfg: clientCfg,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -30,12 +31,15 @@ type statusCall struct {
|
||||
accountID types.AccountID
|
||||
serviceID types.ServiceID
|
||||
connected bool
|
||||
// ctx is captured so tests can assert the notifier received a
|
||||
// fresh background context rather than an inherited request ctx.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
|
||||
func (m *mockStatusNotifier) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected, ctx})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,7 +52,7 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
@@ -279,7 +283,7 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
|
||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
@@ -295,8 +299,12 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second service — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
// Add second service with an already-cancelled caller context —
|
||||
// should notify immediately (client is started) AND the notification
|
||||
// must not inherit the cancelled ctx.
|
||||
cancelledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
err = nb.AddPeer(cancelledCtx, accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
@@ -304,6 +312,9 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
assert.Equal(t, accountID, calls[0].accountID)
|
||||
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
|
||||
assert.True(t, calls[0].connected)
|
||||
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
|
||||
require.NoError(t, calls[0].ctx.Err(),
|
||||
"already-started NotifyStatus must use a background ctx, not the cancelled caller ctx")
|
||||
}
|
||||
|
||||
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
|
||||
@@ -338,7 +349,7 @@ func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
@@ -360,3 +371,53 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
// caller's request- or stream-scoped ctx. Without this, a cancelled
|
||||
// caller ctx could abort the inbound listener bring-up or cause the
|
||||
// management status notification to fail spuriously and leave the
|
||||
// account in a half-connected state.
|
||||
func TestNotifyClientReady_UsesBackgroundCtx(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async")
|
||||
// Pre-populate a client entry so notifyClientReady has something
|
||||
// to mark started + something to enumerate for NotifyStatus.
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{
|
||||
DomainServiceKey("svc.example"): {serviceID: types.ServiceID("svc-1")},
|
||||
},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
var capturedReadyCtx context.Context
|
||||
nb.SetClientLifecycle(
|
||||
func(ctx context.Context, _ types.AccountID, _ *embed.Client) any {
|
||||
capturedReadyCtx = ctx
|
||||
return nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Drive the post-Start path directly; a real client.Start would
|
||||
// need a working management URL.
|
||||
nb.notifyClientReady(accountID, nil)
|
||||
|
||||
require.NotNil(t, capturedReadyCtx, "readyHandler must have been invoked")
|
||||
require.NoError(t, capturedReadyCtx.Err(),
|
||||
"readyHandler must receive a background context, not an inherited cancelled one")
|
||||
deadline, ok := capturedReadyCtx.Deadline()
|
||||
assert.False(t, ok, "readyHandler ctx must have no deadline (background); got %v", deadline)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1, "NotifyStatus must be invoked once per registered service")
|
||||
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
|
||||
require.NoError(t, calls[0].ctx.Err(),
|
||||
"NotifyStatus must receive a background context, not an inherited cancelled one")
|
||||
}
|
||||
|
||||
@@ -1781,11 +1781,14 @@ func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
tlsListener, ok := router.HTTPListener().(*chanListener)
|
||||
require.True(t, ok, "router.HTTPListener() must be the test's chanListener; the test relies on observing its channel directly")
|
||||
|
||||
select {
|
||||
case conn := <-acceptDone:
|
||||
require.NotNil(t, conn)
|
||||
_ = conn.Close()
|
||||
case <-router.HTTPListener().(*chanListener).ch:
|
||||
case <-tlsListener.ch:
|
||||
t.Fatal("plain HTTP request leaked into TLS channel")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("plain HTTP connection never reached plain channel")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -20,14 +21,17 @@ import (
|
||||
type Config struct {
|
||||
// ListenAddr is the TCP address the main listener binds. Required.
|
||||
ListenAddr string
|
||||
// ID identifies this proxy instance to management. Empty value lets
|
||||
// New generate a timestamped default.
|
||||
// ID identifies this proxy instance to management. Empty values are
|
||||
// replaced with a timestamped default at Server.Start time (see
|
||||
// initDefaults), not in New.
|
||||
ID string
|
||||
// Logger is the logrus logger used everywhere. Empty value falls back
|
||||
// to log.StandardLogger().
|
||||
// Logger is the logrus logger used everywhere. Empty values fall
|
||||
// back to log.StandardLogger() at Server.Start time (see
|
||||
// initDefaults), not in New.
|
||||
Logger *log.Logger
|
||||
// Version is the build version string reported to management. Empty
|
||||
// becomes "dev".
|
||||
// values are replaced with "dev" at Server.Start time (see
|
||||
// initDefaults), not in New.
|
||||
Version string
|
||||
// ProxyURL is the public address operators use to reach this proxy.
|
||||
ProxyURL string
|
||||
@@ -125,8 +129,9 @@ type Config struct {
|
||||
// bound — call Start to bring the proxy up. Returning a fully-formed
|
||||
// Server keeps the standalone code path (which still constructs Server
|
||||
// directly) byte-for-byte equivalent.
|
||||
func New(cfg Config) *Server {
|
||||
func New(ctx context.Context, cfg Config) *Server {
|
||||
return &Server{
|
||||
ctx: ctx,
|
||||
ListenAddr: cfg.ListenAddr,
|
||||
ID: cfg.ID,
|
||||
Logger: cfg.Logger,
|
||||
|
||||
@@ -73,7 +73,7 @@ func benchServerWithLatency(b *testing.B, createPeerDelay, statusDelay time.Dura
|
||||
statusUpdateDelay: statusDelay,
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird("bench-proxy", "bench.test",
|
||||
nb := roundtrip.NewNetBird(b.Context(), "bench-proxy", "bench.test",
|
||||
roundtrip.ClientConfig{MgmtAddr: "http://bench.test:9999"},
|
||||
logger, nil, mgmtClient)
|
||||
|
||||
|
||||
@@ -75,6 +75,7 @@ type portRouter struct {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
mgmtClient proto.ProxyServiceClient
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
@@ -281,7 +282,7 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
|
||||
}
|
||||
|
||||
// inboundListenerProto resolves the per-account inbound listener state for
|
||||
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
|
||||
// the SendStatusUpdate payload. Returns nil when --private is off
|
||||
// or the account has no live listener so management treats the field as
|
||||
// absent.
|
||||
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
|
||||
@@ -528,10 +529,10 @@ func (s *Server) initManagementClient() error {
|
||||
}
|
||||
|
||||
// initNetBirdClient builds the multi-tenant embedded NetBird client used
|
||||
// for outbound RoundTripping and (when --private-inbound is on) per-account
|
||||
// for outbound RoundTripping and (when --private is on) per-account
|
||||
// inbound listeners.
|
||||
func (s *Server) initNetBirdClient() {
|
||||
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
s.netbird = roundtrip.NewNetBird(s.ctx, s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
|
||||
@@ -64,7 +64,7 @@ func quietLifecycleLogger() *log.Logger {
|
||||
}
|
||||
|
||||
func TestStopBeforeStartIsNoOp(t *testing.T) {
|
||||
srv := New(Config{Logger: quietLifecycleLogger()})
|
||||
srv := New(t.Context(), Config{Logger: quietLifecycleLogger()})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
@@ -77,7 +77,7 @@ func TestStopBeforeStartIsNoOp(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartFailsWithoutManagement(t *testing.T) {
|
||||
srv := New(Config{
|
||||
srv := New(t.Context(), Config{
|
||||
Logger: quietLifecycleLogger(),
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
ManagementAddress: "://broken-url",
|
||||
@@ -137,7 +137,7 @@ func TestRecordRunErrPreservesFirstFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) {
|
||||
srv := New(Config{Logger: quietLifecycleLogger()})
|
||||
srv := New(t.Context(), Config{Logger: quietLifecycleLogger()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -33,6 +34,12 @@ func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster,
|
||||
// NetBird cannot be deleted via this endpoint; the server returns 404 / 400
|
||||
// for cluster addresses the account does not own.
|
||||
func (a *ReverseProxyClustersAPI) Delete(ctx context.Context, clusterAddress string) error {
|
||||
// Guard against the empty input: url.PathEscape("") returns "" which
|
||||
// would collapse the request URL onto the collection endpoint and
|
||||
// silently delete nothing (or 405 depending on routing).
|
||||
if clusterAddress == "" {
|
||||
return errors.New("clusterAddress is required")
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/clusters/"+url.PathEscape(clusterAddress), nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -88,3 +88,17 @@ func TestReverseProxyClusters_Delete_Err(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestReverseProxyClusters_Delete_EmptyAddress guards against an empty
|
||||
// clusterAddress reaching the wire — that would collapse the URL onto
|
||||
// the collection endpoint instead of a specific cluster. The client
|
||||
// must short-circuit with a typed error before any request is issued.
|
||||
func TestReverseProxyClusters_Delete_EmptyAddress(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/clusters/", func(http.ResponseWriter, *http.Request) {
|
||||
t.Fatal("empty clusterAddress must be rejected client-side; no request should reach the server")
|
||||
})
|
||||
err := c.ReverseProxyClusters.Delete(context.Background(), "")
|
||||
assert.Error(t, err, "empty clusterAddress must surface as an error")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -61,6 +62,12 @@ func (a *ReverseProxyTokensAPI) Create(ctx context.Context, request api.ProxyTok
|
||||
// credentials existed; the plain secret can no longer authenticate any
|
||||
// new proxy registration.
|
||||
func (a *ReverseProxyTokensAPI) Delete(ctx context.Context, tokenID string) error {
|
||||
// Guard against the empty input: url.PathEscape("") returns "" which
|
||||
// would collapse the request URL onto the collection endpoint and
|
||||
// silently delete nothing (or 405 depending on routing).
|
||||
if tokenID == "" {
|
||||
return errors.New("tokenID is required")
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/proxy-tokens/"+url.PathEscape(tokenID), nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -129,3 +129,16 @@ func TestReverseProxyTokens_Delete_Err(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestReverseProxyTokens_Delete_EmptyID guards against an empty tokenID
|
||||
// reaching the wire — url.PathEscape("") would collapse the URL onto
|
||||
// the collection endpoint.
|
||||
func TestReverseProxyTokens_Delete_EmptyID(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/proxy-tokens/", func(http.ResponseWriter, *http.Request) {
|
||||
t.Fatal("empty tokenID must be rejected client-side; no request should reach the server")
|
||||
})
|
||||
err := c.ReverseProxyTokens.Delete(context.Background(), "")
|
||||
assert.Error(t, err, "empty tokenID must surface as an error")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3086,6 +3086,24 @@ components:
|
||||
- enabled
|
||||
- auth
|
||||
- meta
|
||||
allOf:
|
||||
# When private=true, access_groups must be present and non-empty,
|
||||
# and the service mode must be "http". The bearer-auth mutex is
|
||||
# enforced at the service-validation layer
|
||||
# (validatePrivateRequirements) because it sits in a nested
|
||||
# ServiceAuthConfig and isn't cleanly expressible here.
|
||||
- if:
|
||||
required: [private]
|
||||
properties:
|
||||
private:
|
||||
const: true
|
||||
then:
|
||||
required: [access_groups]
|
||||
properties:
|
||||
access_groups:
|
||||
minItems: 1
|
||||
mode:
|
||||
const: http
|
||||
ServiceMeta:
|
||||
type: object
|
||||
properties:
|
||||
@@ -3173,6 +3191,23 @@ components:
|
||||
- name
|
||||
- domain
|
||||
- enabled
|
||||
allOf:
|
||||
# Mirror of the Service conditional: when private=true the
|
||||
# request must carry a non-empty access_groups list and the
|
||||
# mode must be "http". The bearer-auth mutex is enforced at the
|
||||
# service-validation layer (validatePrivateRequirements).
|
||||
- if:
|
||||
required: [private]
|
||||
properties:
|
||||
private:
|
||||
const: true
|
||||
then:
|
||||
required: [access_groups]
|
||||
properties:
|
||||
access_groups:
|
||||
minItems: 1
|
||||
mode:
|
||||
const: http
|
||||
ServiceTargetOptions:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@@ -237,7 +237,7 @@ message SendStatusUpdateRequest {
|
||||
bool certificate_issued = 4;
|
||||
optional string error_message = 5;
|
||||
// Per-account inbound listener state for the account that owns
|
||||
// service_id. Populated only when --private-inbound is enabled and the
|
||||
// service_id. Populated only when --private is enabled and the
|
||||
// embedded client for the account is up. Field numbers >=50 reserved
|
||||
// for observability extensions.
|
||||
optional ProxyInboundListener inbound_listener = 50;
|
||||
|
||||
@@ -2,19 +2,75 @@ package version
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
// DevelopmentVersion is the value of NetbirdVersion() for non-release builds.
|
||||
// Wire-format consumers (management server, dashboard) match against this
|
||||
// string, so it must not change without coordinating those consumers.
|
||||
const DevelopmentVersion = "development"
|
||||
|
||||
// will be replaced with the release version when using goreleaser
|
||||
var version = "development"
|
||||
var version = DevelopmentVersion
|
||||
|
||||
var (
|
||||
VersionRegexp = regexp.MustCompile("^" + v.VersionRegexpRaw + "$")
|
||||
SemverRegexp = regexp.MustCompile("^" + v.SemverRegexpRaw + "$")
|
||||
)
|
||||
|
||||
// NetbirdVersion returns the Netbird version
|
||||
// NetbirdVersion returns the Netbird version. For non-release builds the
|
||||
// value is the literal DevelopmentVersion constant; the VCS revision is
|
||||
// exposed separately via NetbirdCommit so the wire format stays stable.
|
||||
func NetbirdVersion() string {
|
||||
return version
|
||||
}
|
||||
|
||||
// NetbirdCommit returns the VCS revision (truncated to 12 chars) of the
|
||||
// build, with a "-dirty" suffix when the working tree was modified.
|
||||
// Returns an empty string when no build info is embedded (e.g. release
|
||||
// builds compiled by goreleaser without -buildvcs).
|
||||
func NetbirdCommit() string {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
var revision string
|
||||
var modified bool
|
||||
for _, s := range info.Settings {
|
||||
switch s.Key {
|
||||
case "vcs.revision":
|
||||
revision = s.Value
|
||||
case "vcs.modified":
|
||||
modified = s.Value == "true"
|
||||
}
|
||||
}
|
||||
|
||||
if revision == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(revision) > 12 {
|
||||
revision = revision[:12]
|
||||
}
|
||||
|
||||
if modified {
|
||||
revision += "-dirty"
|
||||
}
|
||||
return revision
|
||||
}
|
||||
|
||||
// IsDevelopmentVersion reports whether the given version string identifies
|
||||
// a non-release / development build. It is the single source of truth for
|
||||
// "is this a dev build" checks across the codebase; use it instead of
|
||||
// comparing against the "development" literal or ad-hoc substring checks.
|
||||
//
|
||||
// Matches the bare DevelopmentVersion constant as well as any future
|
||||
// extension such as "development-<sha>" or "development-<sha>-dirty",
|
||||
// while excluding tagged prereleases like "v0.31.1-dev".
|
||||
func IsDevelopmentVersion(v string) bool {
|
||||
return strings.HasPrefix(v, DevelopmentVersion)
|
||||
}
|
||||
|
||||
26
version/version_test.go
Normal file
26
version/version_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package version
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsDevelopmentVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
want bool
|
||||
}{
|
||||
{"development", true},
|
||||
{"development-0823f3ff9ab1", true},
|
||||
{"development-0823f3ff9ab1-dirty", true},
|
||||
{"0.50.0", false},
|
||||
{"v0.31.1-dev", false},
|
||||
{"1.0.0-dev", false},
|
||||
{"dev", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
if got := IsDevelopmentVersion(tt.version); got != tt.want {
|
||||
t.Errorf("IsDevelopmentVersion(%q) = %v, want %v", tt.version, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user