mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients - Add types package with AccountID type - Refactor netbird roundtrip to key clients by AccountID - Multiple domains can share the same client per account - Add status notifier for tunnel connection updates - Add OIDC flags to CLI - Add tests for netbird client management
This commit is contained in:
@@ -3,6 +3,8 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
@@ -18,7 +20,7 @@ const (
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
ServiceId string
|
||||
AccountId string
|
||||
AccountId types.AccountID
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
@@ -36,14 +38,14 @@ func (c *CapturedData) GetServiceId() string {
|
||||
}
|
||||
|
||||
// SetAccountId safely sets the account ID
|
||||
func (c *CapturedData) SetAccountId(accountId string) {
|
||||
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AccountId = accountId
|
||||
}
|
||||
|
||||
// GetAccountId safely gets the account ID
|
||||
func (c *CapturedData) GetAccountId() string {
|
||||
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AccountId
|
||||
@@ -76,13 +78,13 @@ func ServiceIdFromContext(ctx context.Context) string {
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
func withAccountId(ctx context.Context, accountId string) context.Context {
|
||||
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIdKey, accountId)
|
||||
}
|
||||
|
||||
func AccountIdFromContext(ctx context.Context) string {
|
||||
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIdKey)
|
||||
accountId, ok := v.(string)
|
||||
accountId, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
@@ -36,8 +38,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), serviceId)
|
||||
// Set the accountId in the context for later retrieval.
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
|
||||
@@ -6,16 +6,18 @@ import (
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, string, bool) {
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
if p.mappings == nil {
|
||||
p.mappingsMux.RUnlock()
|
||||
@@ -27,10 +29,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
|
||||
}
|
||||
defer p.mappingsMux.RUnlock()
|
||||
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
host = req.Host
|
||||
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||
host := req.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
m, exists := p.mappings[host]
|
||||
if !exists {
|
||||
return nil, "", "", false
|
||||
|
||||
Reference in New Issue
Block a user