Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy

# Conflicts:
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/boot.go
#	management/internals/shared/grpc/proxy.go
#	proxy/internal/auth/middleware.go
#	shared/management/proto/proxy_service.pb.go
#	shared/management/proto/proxy_service.proto
#	shared/management/proto/proxy_service_grpc.pb.go
This commit is contained in:
Alisdair MacLeod
2026-02-04 11:56:04 +00:00
81 changed files with 8413 additions and 458 deletions

View File

@@ -31,20 +31,24 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
)
type Server struct {
mgmtConn *grpc.ClientConn
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
// Mostly used for debugging on management.
startTime time.Time
@@ -62,6 +66,38 @@ type Server struct {
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
DebugEndpointAddress string
}
// NotifyStatus sends a status update to management about tunnel connectivity
func (s *Server) NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error {
status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED
if connected {
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
ReverseProxyId: reverseProxyID,
AccountId: accountID,
Status: status,
CertificateIssued: false,
})
return err
}
// NotifyCertificateIssued sends a notification to management that a certificate was issued
func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error {
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
ReverseProxyId: reverseProxyID,
AccountId: accountID,
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
CertificateIssued: true,
})
return err
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
@@ -105,7 +141,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
"gRPC_address": mgmtURL.Host,
"TLS_enabled": mgmtURL.Scheme == "https",
}).Debug("starting management gRPC client")
s.mgmtConn, err = grpc.NewClient(mgmtURL.Host,
mgmtConn, err := grpc.NewClient(mgmtURL.Host,
grpc.WithTransportCredentials(creds),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 20 * time.Second,
@@ -116,18 +152,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)
}
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
go s.newManagementMappingWorker(ctx, mgmtClient)
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
go s.newManagementMappingWorker(ctx, s.mgmtClient)
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.Logger)
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.Logger, s)
// When generating ACME certificates, start a challenge server.
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory)
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s)
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
Handler: s.acme.HTTPHandler(nil),
@@ -175,7 +211,35 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.auth = auth.NewMiddleware()
// Configure Access logs to management server.
accessLog := accesslog.NewLogger(mgmtClient, s.Logger)
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger)
if s.DebugEndpointEnabled {
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
debugHandler := debug.NewHandler(s.netbird, s.Logger)
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
}
go func() {
s.Logger.WithField("address", debugAddr).Info("starting debug endpoint")
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
defer func() {
if err := s.debug.Close(); err != nil {
s.Logger.Debugf("debug endpoint close: %v", err)
}
}()
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.netbird.StopAll(stopCtx); err != nil {
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
}
}()
// Finally, start the reverse proxy.
s.https = &http.Server{
@@ -279,11 +343,15 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
}
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
if err := s.netbird.AddPeer(ctx, mapping.GetDomain(), mapping.GetSetupKey()); err != nil {
return fmt.Errorf("create peer for domain %q: %w", mapping.GetDomain(), err)
d := domain.Domain(mapping.GetDomain())
accountID := types.AccountID(mapping.GetAccountId())
reverseProxyID := mapping.GetId()
if err := s.netbird.AddPeer(ctx, accountID, d, mapping.GetSetupKey(), reverseProxyID); err != nil {
return fmt.Errorf("create peer for domain %q: %w", d, err)
}
if s.acme != nil {
s.acme.AddDomain(mapping.GetDomain())
s.acme.AddDomain(string(d), string(accountID), reverseProxyID)
}
// Pass the mapping through to the update function to avoid duplicating the
@@ -299,13 +367,12 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
// the auth and proxy mappings.
// Note: this does require the management server to always send a
// full mapping rather than deltas during a modification.
mgmtClient := proto.NewProxyServiceClient(s.mgmtConn)
var schemes []auth.Scheme
if mapping.GetAuth().GetPassword() {
schemes = append(schemes, auth.NewPassword(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
}
if mapping.GetAuth().GetPin() {
schemes = append(schemes, auth.NewPin(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
}
if mapping.GetAuth().GetOidc() != nil {
oidc := mapping.GetAuth().GetOidc()
@@ -317,17 +384,20 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
}))
}
if mapping.GetAuth().GetLink() {
schemes = append(schemes, auth.NewLink(mgmtClient, mapping.GetId(), mapping.GetAccountId()))
schemes = append(schemes, auth.NewLink(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
}
s.auth.AddDomain(mapping.GetDomain(), schemes)
s.proxy.AddMapping(s.protoToMapping(mapping))
}
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
if err := s.netbird.RemovePeer(ctx, mapping.GetDomain()); err != nil {
d := domain.Domain(mapping.GetDomain())
accountID := types.AccountID(mapping.GetAccountId())
if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil {
s.Logger.WithFields(log.Fields{
"domain": mapping.GetDomain(),
"error": err,
"account_id": accountID,
"domain": d,
"error": err,
}).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist")
}
if s.acme != nil {
@@ -356,8 +426,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}
return proxy.Mapping{
ID: mapping.GetId(),
AccountID: mapping.AccountId,
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
Paths: paths,
}
}
// debugEndpointAddr returns the address for the debug endpoint.
// If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string {
if addr == "" {
return "localhost:8444"
}
return addr
}