Compare commits

...

9 Commits

Author SHA1 Message Date
mlsmaycon
75c8fa78e2 Merge branch 'main' into follow-up-private-services 2026-05-27 16:00:40 +02:00
mlsmaycon
4f7c73369b fix(proxy): background ctx for already-started AddPeer notification
The earlier ctx fix covered the async runClientStartup path but missed
the synchronous branch: when a service is added to an already-started
client, AddPeer called NotifyStatus with the caller's request-scoped
ctx. A cancelled request/stream could drop the connected notification
to management. Use context.Background() here too, matching
notifyClientReady.

Extends TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus to
pass a pre-cancelled caller ctx and assert the notification still ran
on a non-cancelled context.
2026-05-26 22:52:11 +02:00
mlsmaycon
924be2116b chore(api,ci,docs,test): private-service schema, proto-check, fixups
Non-functional cleanups and contract/CI hardening around the
private-service work:

API schema (openapi.yml):
- Require a non-empty access_groups and mode=http when private=true,
  on both Service and ServiceRequest, mirroring
  validatePrivateRequirements. mode stays optional-but-constrained
  (empty defaults to http server-side), matching runtime.

CI (proto-version-check.yml):
- Cover renamed .pb.go files (read base via previous_filename).
- Match protoc-gen-go-grpc version headers (optional "- " prefix and
  -gen-go-grpc suffix) so grpc-generated files are in scope.

Docs / comments:
- Reword Config field docs to say defaults are applied at Server.Start
  (initDefaults), not New.
- Rename the obsolete --private-inbound flag to --private across
  comments and the proto doc.

Pre-existing test fixups surfaced by review:
- Repair the integration-tagged validate_session_test.go (SignToken
  signature growth + new Manager interface methods).
- Fix the CI-skip boolean precedence so Windows isn't skipped
  unconditionally.
- Guard the router.HTTPListener type assertion with comma-ok.
2026-05-26 22:00:08 +02:00
mlsmaycon
43c0cb1dc2 fix(rest): reject empty Delete path params in reverse-proxy clients
ReverseProxyClustersAPI.Delete and ReverseProxyTokensAPI.Delete passed
the path parameter into url.PathEscape without an empty check.
PathEscape("") returns "" which collapses the request onto the
collection endpoint ("/api/reverse-proxies/clusters/" /
"/api/reverse-proxies/proxy-tokens/"), so a caller bug delete with no
id reached a routable URL with surprising semantics (typically 405).

Short-circuit with a typed error before the request is built. Tests
mount a handler on the collection path that fails the test if hit, so
the regression is impossible to reintroduce silently.
2026-05-26 21:59:34 +02:00
mlsmaycon
86e24a622d fix(client): include offlinePeers in PeerStateByIP lookup
ReplaceOfflinePeers moves peers into d.offlinePeers but PeerStateByIP
only scanned d.peers. Callers (the local DNS filter via
localPeerConnectivity, embed.Client.IdentityForIP used by the
proxy's tunnel-peer validator) were treating known-but-offline peers
as unknown, which:

- causes the DNS filter to keep returning records pointing at peers
  that have no live tunnel, AND
- makes the proxy's local-roster check deny a request from such a
  peer rather than letting the cached management RPC carry the
  authorisation decision.

Search both slices in PeerStateByIP. Adds a unit test for the IPv4
and IPv6 offline-match paths.
2026-05-26 21:59:34 +02:00
mlsmaycon
af416c656b fix(management): private-service validation + tunnel-IP lookup semantics
- Require an explicit port for L4 cluster targets. validateL4Target
  exempted TargetTypeCluster from the port check, but buildPathMappings
  serializes every L4 target via net.JoinHostPort(host, port) — port=0
  shipped a ":0" upstream. Cluster targets use the same Host/Port
  fields, so the same requirement applies.
- GetPeerByIP returns NotFound on a tunnel-IP miss instead of mapping
  every error to Internal. The proxy's ValidateTunnelPeer probes IPs
  that legitimately aren't in the roster; the miss is expected and now
  distinguishable from a real store failure.
- Thread ctx into getClusterCapability's gorm query so a cancelled
  request doesn't keep the store busy.

Tests updated for the L4-cluster port requirement and the GetPeerByIP
NotFound path.
2026-05-26 21:59:22 +02:00
mlsmaycon
6600f0d45f feat(proxy): short-circuit peer-own-target loops with 421
When a peer that hosts the target of a private service dials its own
service URL the request was being looped through the proxy and back
over WireGuard to the same peer — twice the WG round-trip for no
benefit, with no signal to the caller that something was wrong.

Add isSelfTargetLoop to ReverseProxy.ServeHTTP: when the request
arrived on the per-account overlay listener (IsOverlayOrigin) and the
source tunnel IP matches the target host, refuse the request with 421
Misdirected Request and a body pointing the operator at the backend
directly.

The gate is scoped to overlay origin so requests on the public
listener that happen to share a source IP with the target host are
forwarded normally.
2026-05-26 21:58:58 +02:00
mlsmaycon
878344088f fix(proxy): harden inbound listener resource + startup-ctx handling
Three correctness fixes on the per-account inbound path, with tests:

- Close the logrus ErrorLog PipeWriter on tearDown. WriterLevel hands
  back an *io.PipeWriter backed by a pipe + scanner goroutine that the
  caller owns; the two writers per account (https + plain) were never
  closed, leaking the pipe and goroutine on every teardown.
- Run the post-Start hooks on context.Background(). runClientStartup
  is launched in a goroutine from AddPeer and was inheriting the
  caller's request-scoped ctx, so a cancelled request could abort the
  inbound bring-up or fail the management status notification. The
  tail is split into notifyClientReady so the contract is testable.

Tests cover the PipeWriter close behaviour and assert the readyHandler
+ NotifyStatus calls receive a non-cancelled background context.
2026-05-26 21:58:52 +02:00
mlsmaycon
e09d51c1a8 fix(proxy): gate tunnel-peer fast-path on inbound listener marker
forwardWithTunnelPeer previously accepted any RFC1918 / ULA / CGNAT
source IP, so a public client whose address happened to fall in those
ranges could bypass the configured operator auth scheme by colliding
with a known tunnel IP. The fast-path is now gated on
TunnelLookupFromContext(r.Context()) being present — that context value
is attached only by the per-account inbound (overlay) listener, so the
host-facing listener never enters this branch.

Tests updated to reflect the new requirement: requests that don't
carry the inbound marker now fall through to the regular auth flow.
2026-05-26 21:58:23 +02:00
28 changed files with 692 additions and 89 deletions

View File

@@ -20,15 +20,30 @@ jobs:
per_page: 100,
});
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
// Cover renamed .pb.go files in addition to plain edits.
// Renamed entries land under the new path with previous_filename
// pointing at the base-side name, so we read the base content
// from the old path when present.
const changedPbFiles = files
.filter(f => (f.status === 'modified' || f.status === 'renamed')
&& f.filename.endsWith('.pb.go'))
.map(f => ({
headPath: f.filename,
basePath: f.previous_filename || f.filename,
}));
if (changedPbFiles.length === 0) {
console.log('No modified or renamed .pb.go files to check');
return;
}
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
// Matches the generator version headers protoc writes at the top
// of generated files:
// // protoc v3.21.12
// // protoc-gen-go v1.26.0
// // - protoc-gen-go-grpc v1.6.1 (grpc files prefix with "- ")
// The optional "- " prefix and the optional -gen-go / -gen-go-grpc
// suffixes keep the *_grpc.pb.go headers in scope.
const versionPattern = /^\s*\/\/\s+(?:-\s+)?protoc(?:-gen-go(?:-grpc)?)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
@@ -55,20 +70,22 @@ jobs:
}
const violations = [];
for (const file of modifiedPbFiles) {
for (const file of changedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
getVersionHeader(file.basePath, baseSha),
getVersionHeader(file.headPath, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
`Skipping ${file.headPath}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({
file: file.filename,
file: file.basePath === file.headPath
? file.headPath
: `${file.basePath} → ${file.headPath}`,
base: base.lines,
head: head.lines,
});

View File

@@ -310,8 +310,12 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
// address so dual-stack peers are reachable on either family. Searches
// both d.peers and d.offlinePeers — peers that have been moved into
// the offline slice by ReplaceOfflinePeers are still part of the
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
// need to recognise them rather than treating them as unknown. Returns
// the zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
@@ -324,6 +328,11 @@ func (d *Status) PeerStateByIP(ip string) (State, bool) {
return state, true
}
}
for _, state := range d.offlinePeers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}

View File

@@ -90,6 +90,28 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
// been moved into the offline slice via ReplaceOfflinePeers. Callers
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
// rather than unknown — otherwise authentication / DNS filtering treats
// known-but-offline peers as foreign IPs.
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
status.ReplaceOfflinePeers([]State{
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
})
state, ok := status.PeerStateByIP("100.64.0.20")
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
state, ok = status.PeerStateByIP("fd00::20")
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -932,7 +932,11 @@ func (s *Service) validateL4Target(target *Target) error {
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
if target.TargetType != TargetTypeCluster && target.Port == 0 {
// Cluster targets resolve their upstream host:port from the target's
// own Host/Port fields just like the other L4 types — buildPathMappings
// emits net.JoinHostPort(target.Host, target.Port) for every L4
// target, so allowing port=0 here would let ":0" reach the proxy.
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType {

View File

@@ -1176,7 +1176,12 @@ func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
// TestValidate_L4ClusterTarget_RequiresPort confirms that an L4 cluster
// target without an explicit port is rejected. buildPathMappings emits
// net.JoinHostPort(target.Host, target.Port) for every L4 target — so
// allowing port=0 would let the proxy ship ":0" upstreams. The port
// requirement is the same as every other L4 target type.
func TestValidate_L4ClusterTarget_RequiresPort(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
@@ -1186,7 +1191,12 @@ func TestValidate_L4ClusterTarget(t *testing.T) {
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
assert.ErrorContains(t, rp.Validate(), "port is required",
"L4 cluster target must require an explicit port like other L4 target types")
rp.Targets[0].Port = 5432
rp.Targets[0].Host = "db.lan"
require.NoError(t, rp.Validate(), "L4 cluster target with host:port must validate")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {

View File

@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
token, err := sessionkey.SignToken(privKeyB64, userID, "", domain, auth.MethodOIDC, nil, nil, time.Hour)
require.NoError(t, err)
return token
}
@@ -394,6 +394,10 @@ func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Cont
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
@@ -401,3 +405,24 @@ type testValidateSessionUsersManager struct {
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
func (m *testValidateSessionUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, nil, err
}
if len(user.AutoGroups) == 0 {
return user, nil, nil
}
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
if err != nil {
return nil, nil, err
}
groups := make([]*types.Group, 0, len(user.AutoGroups))
for _, id := range user.AutoGroups {
if g, ok := groupsMap[id]; ok && g != nil {
groups = append(groups, g)
}
}
return user, groups, nil
}

View File

@@ -4734,7 +4734,13 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
result := tx.
Take(&peer, fmt.Sprintf("account_id = ? AND %s = ?", column), accountID, jsonValue)
if result.Error != nil {
// no logging here
// A tunnel-IP miss is an expected outcome (e.g. the proxy's
// ValidateTunnelPeer probing an address that isn't in the
// account roster); surface it as NotFound so callers can tell
// it apart from a real store failure.
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer with ip %s not found", ip.String())
}
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
@@ -5962,6 +5968,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
}
err := s.db.
WithContext(ctx).
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").

View File

@@ -13,7 +13,7 @@ import (
)
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
if os.Getenv("CI") == "true" && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
t.Skip("skip CI tests on darwin and windows")
}

View File

@@ -491,6 +491,27 @@ func Test_GetAccount(t *testing.T) {
})
}
// TestSqlStore_GetPeerByIP_NotFound pins the not-found semantics the
// proxy's ValidateTunnelPeer relies on: a tunnel-IP that isn't in the
// account roster must surface as a NotFound error (not a generic
// Internal) so callers can distinguish an expected miss from a real
// store failure. A known IP still resolves.
func TestSqlStore_GetPeerByIP_NotFound(t *testing.T) {
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peer, err := store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("192.168.0.0"))
require.NoError(t, err, "known tunnel IP must resolve")
require.NotNil(t, peer)
_, err = store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("100.65.0.99"))
require.Error(t, err, "unknown tunnel IP must error")
parsedErr, ok := status.FromError(err)
require.True(t, ok, "error must be a status error")
require.Equal(t, status.NotFound, parsedErr.Type(), "tunnel-IP miss must be NotFound, not Internal")
})
}
func TestSqlStore_SavePeer(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
stdlog "log"
"net"
"net/http"
@@ -42,7 +43,7 @@ const privateInboundPortHTTPS = 443
const privateInboundPortHTTP = 80
// inboundManager wires per-account inbound listeners into the proxy
// pipeline when --private-inbound is enabled. When disabled the manager
// pipeline when --private is enabled. When disabled the manager
// is nil and every method on *Server that touches it short-circuits.
type inboundManager struct {
logger *log.Logger
@@ -55,15 +56,18 @@ type inboundManager struct {
}
// inboundEntry owns the listeners, router and HTTP servers for a single
// account's embedded netstack.
// account's embedded netstack. errorLogWriters retain the logrus pipe
// writers backing each http.Server's ErrorLog so tearDown can close
// them — otherwise the pipe + its scanner goroutine leak per account.
type inboundEntry struct {
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
cancel context.CancelFunc
wg sync.WaitGroup
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
errorLogWriters []*io.PipeWriter
cancel context.CancelFunc
wg sync.WaitGroup
}
// pendingInboundRoute holds a route that arrived before the account's
@@ -147,30 +151,34 @@ func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID,
return types.WithOverlayOrigin(ctx)
}
httpsErrLog, httpsErrW := newInboundErrorLog(m.logger, "https", accountID)
httpErrLog, httpErrW := newInboundErrorLog(m.logger, "http", accountID)
httpsServer := &http.Server{
Handler: scopedHandler,
TLSConfig: m.tlsConfig,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
ErrorLog: httpsErrLog,
ConnContext: markOverlayOrigin,
}
httpServer := &http.Server{
Handler: scopedHandler,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
ErrorLog: httpErrLog,
ConnContext: markOverlayOrigin,
}
runCtx, cancel := context.WithCancel(ctx)
entry := &inboundEntry{
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
cancel: cancel,
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
errorLogWriters: []*io.PipeWriter{httpsErrW, httpErrW},
cancel: cancel,
}
entry.wg.Add(1)
@@ -237,6 +245,14 @@ func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry
m.logger.Debugf("close per-account plain listener: %v", err)
}
entry.wg.Wait()
// Close the ErrorLog pipes only after the http.Servers have fully
// stopped so any straggling stdlib write doesn't race with the
// close. Each writer also tears down the logrus scanner goroutine.
for _, w := range entry.errorLogWriters {
if err := w.Close(); err != nil {
m.logger.Debugf("close per-account inbound error log writer: %v", err)
}
}
}
// AddRoute records an SNI/host route on the account's per-account router.
@@ -374,7 +390,7 @@ func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListene
}
// Snapshot returns the inbound listener state for every account that has
// a live listener at call time. Empty when --private-inbound is off or
// a live listener at call time. Empty when --private is off or
// no accounts have come up yet.
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
if m == nil {
@@ -497,7 +513,7 @@ func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
// peerstore lookup to every request's context before delegating to next.
// Calling on the host-level listener is a no-op because that path never
// installs this wrapper, so the existing behaviour stays byte-for-byte
// identical when --private-inbound is off or the request didn't arrive
// identical when --private is off or the request didn't arrive
// on a per-account listener.
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
if lookup == nil {
@@ -538,10 +554,14 @@ func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.Inboun
}
// newInboundErrorLog routes a per-account http.Server's stdlib error
// stream through logrus at warn level.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
return stdlog.New(logger.WithFields(log.Fields{
// stream through logrus at warn level. The returned PipeWriter must be
// closed by the caller (tearDown) once the http.Server has shut down —
// otherwise the pipe and its scanner goroutine leak per account, see
// logrus.Entry.WriterLevel.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) (*stdlog.Logger, *io.PipeWriter) {
w := logger.WithFields(log.Fields{
"inbound-http": scheme,
"account_id": accountID,
}).WriterLevel(log.WarnLevel), "", 0)
}).WriterLevel(log.WarnLevel)
return stdlog.New(w, "", 0), w
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
@@ -139,7 +140,7 @@ func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
// bit reported upstream tracks --private exclusively. The previous
// --private-inbound flag has been folded into --private.
// --private flag has been folded into --private.
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
tests := []struct {
name string
@@ -318,7 +319,7 @@ func TestInboundManager_ListenerInfo(t *testing.T) {
}
// TestInboundManager_NilManagerSafe ensures the observability accessors
// are safe to call when --private-inbound is off (nil manager).
// are safe to call when --private is off (nil manager).
func TestInboundManager_NilManagerSafe(t *testing.T) {
var mgr *inboundManager
_, ok := mgr.ListenerInfo("anything")
@@ -482,6 +483,38 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
}
// TestNewInboundErrorLog_WriterIsCloseable guards the close path on the
// logrus PipeWriter that backs each per-account http.Server's ErrorLog.
// logrus.Entry.WriterLevel returns an *io.PipeWriter that owns a pipe +
// scanner goroutine; the caller must Close() it on teardown or the
// resources leak per account. The contract is verified two ways:
//
// - the constructor returns a non-nil writer the caller can keep,
// - writing to the writer after Close() fails with io.ErrClosedPipe,
// which is the only externally observable sign that Close was wired.
//
// A leaking refactor (forgetting to thread the writer to tearDown, or
// dropping the Close call) would still pass this test individually but
// fail an integration goleak check; this unit test is the cheap first
// line of defence.
func TestNewInboundErrorLog_WriterIsCloseable(t *testing.T) {
logger := quietLogger()
stdLog, writer := newInboundErrorLog(logger, "https", types.AccountID("acct-1"))
require.NotNil(t, stdLog, "newInboundErrorLog must return a non-nil *log.Logger")
require.NotNil(t, writer, "newInboundErrorLog must return the underlying PipeWriter so tearDown can Close it")
// First Close succeeds.
require.NoError(t, writer.Close(), "PipeWriter.Close should succeed the first time")
// After Close, the writer must refuse new writes — that's the only
// behavioural signal that the pipe (and its scanner goroutine) has
// shut down.
_, err := writer.Write([]byte("post-close write\n"))
require.ErrorIs(t, err, io.ErrClosedPipe,
"writes after Close must surface io.ErrClosedPipe so callers know the writer is gone")
}
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
// 127.0.0.1 — only used by tests that need a working TLS handshake.
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----

View File

@@ -346,13 +346,15 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
// management unreachable, peer unknown, user not in group) returns false so
// the caller falls back to the existing OIDC scheme dispatch.
//
// Phase 3 adds a local-first short-circuit: when the request arrived on a
// per-account inbound listener the context carries a peerstore lookup
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
// roster the proxy denies fast without calling management. If the lookup
// confirms a known peer the RPC still runs for the user-identity tail
// (UserID + group access), but its result is cached for tunnelCacheTTL so
// repeat requests skip management entirely.
// The fast-path is gated on TunnelLookupFromContext(r.Context()) being
// present — that context value is attached only by the per-account
// inbound (overlay) listener. The host listener never sets it, so a
// public client whose source IP happens to fall inside an RFC1918 / ULA
// / CGNAT range can't impersonate a mesh peer by colliding with a
// tunnel-IP. Once we know the request arrived over WireGuard the
// per-account peerstore lookup is consulted: a miss denies fast (no
// management round-trip), a hit gates the cached ValidateTunnelPeer RPC
// that mints the session JWT.
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
if mw.sessionValidator == nil {
return false
@@ -361,18 +363,24 @@ func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Reque
if !clientIP.IsValid() {
return false
}
// Anti-spoof: only honour the tunnel-peer fast-path on requests that
// were stamped by an overlay listener. Without that marker an
// attacker could send a request from a colliding RFC1918 / CGNAT
// source on the public listener and bypass operator auth.
lookup := TunnelLookupFromContext(r.Context())
if lookup == nil {
return false
}
if !isTunnelSourceIP(clientIP) {
return false
}
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{

View File

@@ -1227,3 +1227,93 @@ func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
}
// stubTunnelValidator records ValidateTunnelPeer calls so a test can
// assert whether the fast-path reached management.
type stubTunnelValidator struct {
called bool
resp *proto.ValidateTunnelPeerResponse
}
func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
return nil, errors.New("not used in this test")
}
func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
s.called = true
return s.resp, nil
}
// TestProtect_TunnelPeerFastPath_RequiresInboundMarker guards the
// anti-spoof gate: a request with an RFC1918 source IP arriving on the
// public listener (no TunnelLookupFromContext attached) must not be
// allowed to take the tunnel-peer fast-path. Without this gate a public
// client whose source IP happens to fall inside an RFC1918 range could
// bypass the configured auth scheme by colliding with a known tunnel
// IP.
func TestProtect_TunnelPeerFastPath_RequiresInboundMarker(t *testing.T) {
validator := &stubTunnelValidator{
resp: &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "should-not-be-used",
UserId: "user-1",
},
}
mw := NewMiddleware(log.StandardLogger(), validator, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
// Request from an RFC1918 source IP on the public listener — no
// TunnelLookupFromContext attached. The fast-path must reject this
// and fall through to the PIN scheme (which renders 401 on plain
// HTTP for a non-authenticated request).
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.64.0.5:5000"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, validator.called,
"ValidateTunnelPeer must not be invoked when the request lacks the inbound TunnelLookup marker")
assert.Equal(t, http.StatusUnauthorized, rec.Code,
"without the inbound marker the request must fall through to the operator auth scheme")
}
// TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker verifies
// the positive side: a request marked as overlay-origin (carrying the
// TunnelLookup context value) and matching a tunnel-IP range does take
// the fast-path and reach management.
func TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker(t *testing.T) {
validator := &stubTunnelValidator{
resp: &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tunnel-session-token",
UserId: "user-1",
},
}
mw := NewMiddleware(log.StandardLogger(), validator, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.64.0.5:5000"
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, validator.called,
"ValidateTunnelPeer must run when the request carries the inbound TunnelLookup marker")
assert.Equal(t, http.StatusOK, rec.Code,
"a successful tunnel-peer validation must forward to the next handler")
}

View File

@@ -101,7 +101,10 @@ func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
w, r := newTunnelRequest("100.64.0.10:55555")
cd := proxy.NewCapturedData("")
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
r = r.WithContext(proxy.WithCapturedData(WithTunnelLookup(r.Context(), lookup), cd))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
@@ -148,9 +151,13 @@ func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
}
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
// behaviour stays intact on the host-level listener (no lookup attached).
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
// TestForwardWithTunnelPeer_NoLookupRefusesFastPath guards the
// anti-spoof gate: requests that didn't arrive on the per-account
// inbound listener (no TunnelLookup attached) must never reach
// management's ValidateTunnelPeer, even when the source IP looks like
// a tunnel address. A colliding RFC1918 / CGNAT source on the public
// listener would otherwise impersonate a mesh peer.
func TestForwardWithTunnelPeer_NoLookupRefusesFastPath(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
@@ -165,9 +172,9 @@ func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "host-level path forwards on positive RPC result")
assert.True(t, called, "next handler runs on host-level success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
assert.False(t, handled, "fast-path must refuse without the inbound marker")
assert.False(t, called, "next handler must not run")
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must not be invoked without the inbound marker")
}
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
@@ -201,8 +208,13 @@ func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
}
mw := newTunnelMiddleware(t, validator)
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
for i := 0; i < 4; i++ {
w, r := newTunnelRequest("100.64.0.10:55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
@@ -226,11 +238,21 @@ func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
// The fast-path requires the inbound-listener marker on the context.
// The peerstore lookup itself is account-agnostic at this level
// (one TunnelLookupFunc per account is attached by inbound.go); a
// trivial "always hit" lookup is enough to exercise the cache-key
// branch this test covers.
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
for _, host := range []string{"svc-a.example", "svc-b.example"} {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
r.Host = host
r.RemoteAddr = "100.64.0.10:55555"
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig(host)
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
require.True(t, handled, "host %s should forward", host)
@@ -314,9 +336,17 @@ func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
w.WriteHeader(http.StatusOK)
}))
// Per-account inbound listener attaches WithTunnelLookup; without it
// forwardWithTunnelPeer refuses to take the fast-path. Mirror the
// real flow so this test exercises the post-gating success branch.
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

View File

@@ -131,7 +131,7 @@ func (h *Handler) SetCertStatus(cs certStatus) {
// SetInboundProvider wires per-account inbound listener observability.
// Pass nil (or skip the call) to keep the inbound section out of debug
// responses on proxies that don't run --private-inbound.
// responses on proxies that don't run --private.
func (h *Handler) SetInboundProvider(p InboundProvider) {
h.inbound = p
}

View File

@@ -66,6 +66,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Loop guard for private services: a peer that hosts the target
// dialing its own service URL would round-trip its own traffic
// through the proxy and back over WG to itself. Refuse the request
// with 421 (Misdirected Request) so the caller sees an explicit
// error instead of silently doubling tunnel traffic.
if p.isSelfTargetLoop(r, result.target.URL) {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginNoRoute)
}
requestID := getRequestID(r)
web.ServeErrorPage(w, r, http.StatusMisdirectedRequest, "Loop Detected",
"This peer is the target of the requested service. Reach the backend directly instead of dialing the public service URL from the same machine.",
requestID, web.ErrorStatus{Proxy: true, Destination: false})
return
}
ctx := r.Context()
// Set the account ID in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID)
@@ -107,6 +123,32 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rp.ServeHTTP(w, r.WithContext(ctx))
}
// isSelfTargetLoop reports whether an overlay-origin request is about to
// be forwarded back to the very peer that initiated it. The detection
// is intentionally narrow: it only fires when the request arrived on
// the per-account inbound (overlay) listener (so we're confident the
// source address is the caller's tunnel IP), and only when the resolved
// target host matches that tunnel IP. Catching this here returns 421 to
// the caller instead of letting the proxy round-trip its own traffic
// over WG twice.
func (p *ReverseProxy) isSelfTargetLoop(r *http.Request, target *url.URL) bool {
if target == nil {
return false
}
if !types.IsOverlayOrigin(r.Context()) {
return false
}
srcIP := extractHostIP(r.RemoteAddr)
if !srcIP.IsValid() {
return false
}
targetIP, err := netip.ParseAddr(target.Hostname())
if err != nil {
return false
}
return srcIP.Unmap() == targetIP.Unmap()
}
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
// inbound requests to target the backend service while setting security-relevant
// forwarding headers and stripping proxy authentication credentials.

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
)
@@ -1285,6 +1286,103 @@ func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
"X-NetBird-Groups must not be set when every group label is rejected")
}
// nopOKTransport returns 200 for every request without dialing — used
// by the self-target-loop tests so the non-loop cases don't pay a real
// TCP-dial timeout.
type nopOKTransport struct{}
func (nopOKTransport) RoundTrip(*http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Header: http.Header{}}, nil
}
// TestServeHTTP_SelfTargetLoopReturns421 covers the loop guard for
// private services: when a peer dials a service whose only target is
// the peer itself, the proxy must refuse with 421 (Misdirected
// Request) rather than round-tripping the request back over WG to
// the same peer.
func TestServeHTTP_SelfTargetLoopReturns421(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "private.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.5:55555"
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMisdirectedRequest, rec.Code,
"a peer dialing a service whose target is itself must get 421")
}
// TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough verifies
// the guard is scoped to overlay-origin requests. A public-listener
// request that happens to share a source IP with the target host must
// not be misinterpreted as a loop — the gating relies on the inbound
// marker being attached only by the per-account overlay listener.
func TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "public.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://public.svc/", nil)
req.Host = "public.svc"
req.RemoteAddr = "100.64.0.5:55555"
// No WithOverlayOrigin → the guard must not fire.
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
"a non-overlay request with a colliding source IP must not be flagged as a loop")
}
// TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough confirms
// that overlay-origin requests with a source IP that does *not* match
// the target host are forwarded normally.
func TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "private.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.99:55555" // different from the target
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
"overlay request with a non-matching source IP must not be flagged as a loop")
}
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
// that carry CapturedData with no identity fields populated (e.g. the
// auth middleware ran but the request didn't authenticate). Both

View File

@@ -213,7 +213,11 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
// Use a background context, not the caller's: the management
// connection notification must land even if the request /
// stream that triggered this registration is cancelled.
// Mirrors the async runClientStartup path.
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
@@ -242,8 +246,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
// retry on the first request via RoundTrip. runClientStartup uses its
// own background context so the caller's request-scoped ctx can't
// cancel the inbound bring-up.
go n.runClientStartup(accountID, entry.client)
return nil
}
@@ -355,8 +361,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}, nil
}
// runClientStartup starts the client and notifies registered services on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
// runClientStartup starts the client and notifies registered services on
// success. This function runs in a goroutine launched from AddPeer, so it
// must never inherit the caller's request-scoped context — a canceled
// request must not abort the inbound listener bring-up or the management
// status notification. The embedded client.Start gets its own bounded
// startCtx; once Start succeeds, notifyClientReady takes over with a
// fresh context.Background() (see that function for the contract).
func (n *NetBird) runClientStartup(accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -369,7 +381,17 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
// Mark client as started and collect services to notify outside the lock.
n.notifyClientReady(accountID, client)
}
// notifyClientReady marks the account's client as started, fires the
// readyHandler hook, and notifies management of the new tunnel
// connection for every registered service. It is split out of
// runClientStartup so a regression test can drive the post-Start tail
// without needing a live embedded client. The contract that the
// hooks/notifier see context.Background() — never the AddPeer caller's
// ctx — lives here.
func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Client) {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
@@ -384,8 +406,10 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
readyHandler := n.readyHandler
n.clientsMux.Unlock()
bgCtx := context.Background()
if readyHandler != nil {
state := readyHandler(ctx, accountID, client)
state := readyHandler(bgCtx, accountID, client)
n.clientsMux.Lock()
if e, ok := n.clients[accountID]; ok {
e.inbound = state
@@ -404,7 +428,7 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
for _, sn := range toNotify {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
if err := n.statusNotifier.NotifyStatus(bgCtx, accountID, sn.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": sn.key,

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -30,12 +31,15 @@ type statusCall struct {
accountID types.AccountID
serviceID types.ServiceID
connected bool
// ctx is captured so tests can assert the notifier received a
// fresh background context rather than an inherited request ctx.
ctx context.Context
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
func (m *mockStatusNotifier) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected, ctx})
return nil
}
@@ -295,8 +299,12 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second service — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
// Add second service with an already-cancelled caller context —
// should notify immediately (client is started) AND the notification
// must not inherit the cancelled ctx.
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
err = nb.AddPeer(cancelledCtx, accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
calls := notifier.calls()
@@ -304,6 +312,9 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
assert.Equal(t, accountID, calls[0].accountID)
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
assert.True(t, calls[0].connected)
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
require.NoError(t, calls[0].ctx.Err(),
"already-started NotifyStatus must use a background ctx, not the cancelled caller ctx")
}
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
@@ -360,3 +371,53 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
assert.False(t, calls[0].connected)
}
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
// a fresh context.Background() rather than inheriting the AddPeer
// caller's request- or stream-scoped ctx. Without this, a cancelled
// caller ctx could abort the inbound listener bring-up or cause the
// management status notification to fail spuriously and leave the
// account in a half-connected state.
func TestNotifyClientReady_UsesBackgroundCtx(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("acct-async")
// Pre-populate a client entry so notifyClientReady has something
// to mark started + something to enumerate for NotifyStatus.
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{
DomainServiceKey("svc.example"): {serviceID: types.ServiceID("svc-1")},
},
}
nb.clientsMux.Unlock()
var capturedReadyCtx context.Context
nb.SetClientLifecycle(
func(ctx context.Context, _ types.AccountID, _ *embed.Client) any {
capturedReadyCtx = ctx
return nil
},
nil,
)
// Drive the post-Start path directly; a real client.Start would
// need a working management URL.
nb.notifyClientReady(accountID, nil)
require.NotNil(t, capturedReadyCtx, "readyHandler must have been invoked")
require.NoError(t, capturedReadyCtx.Err(),
"readyHandler must receive a background context, not an inherited cancelled one")
deadline, ok := capturedReadyCtx.Deadline()
assert.False(t, ok, "readyHandler ctx must have no deadline (background); got %v", deadline)
calls := notifier.calls()
require.Len(t, calls, 1, "NotifyStatus must be invoked once per registered service")
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
require.NoError(t, calls[0].ctx.Err(),
"NotifyStatus must receive a background context, not an inherited cancelled one")
}

View File

@@ -1781,11 +1781,14 @@ func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
}
}()
tlsListener, ok := router.HTTPListener().(*chanListener)
require.True(t, ok, "router.HTTPListener() must be the test's chanListener; the test relies on observing its channel directly")
select {
case conn := <-acceptDone:
require.NotNil(t, conn)
_ = conn.Close()
case <-router.HTTPListener().(*chanListener).ch:
case <-tlsListener.ch:
t.Fatal("plain HTTP request leaked into TLS channel")
case <-time.After(3 * time.Second):
t.Fatal("plain HTTP connection never reached plain channel")

View File

@@ -20,14 +20,17 @@ import (
type Config struct {
// ListenAddr is the TCP address the main listener binds. Required.
ListenAddr string
// ID identifies this proxy instance to management. Empty value lets
// New generate a timestamped default.
// ID identifies this proxy instance to management. Empty values are
// replaced with a timestamped default at Server.Start time (see
// initDefaults), not in New.
ID string
// Logger is the logrus logger used everywhere. Empty value falls back
// to log.StandardLogger().
// Logger is the logrus logger used everywhere. Empty values fall
// back to log.StandardLogger() at Server.Start time (see
// initDefaults), not in New.
Logger *log.Logger
// Version is the build version string reported to management. Empty
// becomes "dev".
// values are replaced with "dev" at Server.Start time (see
// initDefaults), not in New.
Version string
// ProxyURL is the public address operators use to reach this proxy.
ProxyURL string

View File

@@ -281,7 +281,7 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
}
// inboundListenerProto resolves the per-account inbound listener state for
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
// the SendStatusUpdate payload. Returns nil when --private is off
// or the account has no live listener so management treats the field as
// absent.
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
@@ -528,7 +528,7 @@ func (s *Server) initManagementClient() error {
}
// initNetBirdClient builds the multi-tenant embedded NetBird client used
// for outbound RoundTripping and (when --private-inbound is on) per-account
// for outbound RoundTripping and (when --private is on) per-account
// inbound listeners.
func (s *Server) initNetBirdClient() {
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{

View File

@@ -2,6 +2,7 @@ package rest
import (
"context"
"errors"
"net/url"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -33,6 +34,12 @@ func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster,
// NetBird cannot be deleted via this endpoint; the server returns 404 / 400
// for cluster addresses the account does not own.
func (a *ReverseProxyClustersAPI) Delete(ctx context.Context, clusterAddress string) error {
// Guard against the empty input: url.PathEscape("") returns "" which
// would collapse the request URL onto the collection endpoint and
// silently delete nothing (or 405 depending on routing).
if clusterAddress == "" {
return errors.New("clusterAddress is required")
}
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/clusters/"+url.PathEscape(clusterAddress), nil, nil)
if err != nil {
return err

View File

@@ -88,3 +88,17 @@ func TestReverseProxyClusters_Delete_Err(t *testing.T) {
assert.Error(t, err)
})
}
// TestReverseProxyClusters_Delete_EmptyAddress guards against an empty
// clusterAddress reaching the wire — that would collapse the URL onto
// the collection endpoint instead of a specific cluster. The client
// must short-circuit with a typed error before any request is issued.
func TestReverseProxyClusters_Delete_EmptyAddress(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/clusters/", func(http.ResponseWriter, *http.Request) {
t.Fatal("empty clusterAddress must be rejected client-side; no request should reach the server")
})
err := c.ReverseProxyClusters.Delete(context.Background(), "")
assert.Error(t, err, "empty clusterAddress must surface as an error")
})
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"net/url"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -61,6 +62,12 @@ func (a *ReverseProxyTokensAPI) Create(ctx context.Context, request api.ProxyTok
// credentials existed; the plain secret can no longer authenticate any
// new proxy registration.
func (a *ReverseProxyTokensAPI) Delete(ctx context.Context, tokenID string) error {
// Guard against the empty input: url.PathEscape("") returns "" which
// would collapse the request URL onto the collection endpoint and
// silently delete nothing (or 405 depending on routing).
if tokenID == "" {
return errors.New("tokenID is required")
}
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/proxy-tokens/"+url.PathEscape(tokenID), nil, nil)
if err != nil {
return err

View File

@@ -129,3 +129,16 @@ func TestReverseProxyTokens_Delete_Err(t *testing.T) {
assert.Error(t, err)
})
}
// TestReverseProxyTokens_Delete_EmptyID guards against an empty tokenID
// reaching the wire — url.PathEscape("") would collapse the URL onto
// the collection endpoint.
func TestReverseProxyTokens_Delete_EmptyID(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/proxy-tokens/", func(http.ResponseWriter, *http.Request) {
t.Fatal("empty tokenID must be rejected client-side; no request should reach the server")
})
err := c.ReverseProxyTokens.Delete(context.Background(), "")
assert.Error(t, err, "empty tokenID must surface as an error")
})
}

View File

@@ -3086,6 +3086,24 @@ components:
- enabled
- auth
- meta
allOf:
# When private=true, access_groups must be present and non-empty,
# and the service mode must be "http". The bearer-auth mutex is
# enforced at the service-validation layer
# (validatePrivateRequirements) because it sits in a nested
# ServiceAuthConfig and isn't cleanly expressible here.
- if:
required: [private]
properties:
private:
const: true
then:
required: [access_groups]
properties:
access_groups:
minItems: 1
mode:
const: http
ServiceMeta:
type: object
properties:
@@ -3173,6 +3191,23 @@ components:
- name
- domain
- enabled
allOf:
# Mirror of the Service conditional: when private=true the
# request must carry a non-empty access_groups list and the
# mode must be "http". The bearer-auth mutex is enforced at the
# service-validation layer (validatePrivateRequirements).
- if:
required: [private]
properties:
private:
const: true
then:
required: [access_groups]
properties:
access_groups:
minItems: 1
mode:
const: http
ServiceTargetOptions:
type: object
properties:

View File

@@ -237,7 +237,7 @@ message SendStatusUpdateRequest {
bool certificate_issued = 4;
optional string error_message = 5;
// Per-account inbound listener state for the account that owns
// service_id. Populated only when --private-inbound is enabled and the
// service_id. Populated only when --private is enabled and the
// embedded client for the account is up. Field numbers >=50 reserved
// for observability extensions.
optional ProxyInboundListener inbound_listener = 50;